si+sj中间有一个切割点,我们在t上枚举这个切割点i,即以t[i]作为最后一个字符时求有多少si可以匹配,以t[i+1]作为第一个字符时有多少sj可以匹配
那么对s串正着建一个ac自动机,反着建一个自动机,然后t正反各匹配一次,用sum[]数组记录t[i]作为最后一个字符可以匹配的串数量
注意:求sum数组时,暴力跳fail显然会t,考虑到跳fail是为了统计匹配串的后缀,那么我们在build时,就可以在处理fail指针时就可以把那个fail的end加到now的end上去,这样就避免了暴力跳fail
#include<bits/stdc++.h> using namespace std; #define N 200005 struct Trie{ int nxt[N][26],fail[N],end[N]; int root,L; int newnode(){ memset(nxt[L],-1,sizeof nxt[L]); end[L]=0; return L++; } void init(){ L++; root=newnode(); } void insert(char buf[]){ int len=strlen(buf+1); int now=root; for(int i=1;i<=len;i++){ if(nxt[now][buf[i]-'a']==-1) nxt[now][buf[i]-'a']=newnode(); now=nxt[now][buf[i]-'a']; } end[now]++; } void build(){ queue<int>q; fail[root]=root; for(int i=0;i<26;i++) if(nxt[root][i]==-1) nxt[root][i]=root; else { fail[nxt[root][i]]=root; q.push(nxt[root][i]); } while(q.size()){ int now=q.front(); q.pop(); for(int i=0;i<26;i++) if(nxt[now][i]==-1) nxt[now][i]=nxt[fail[now]][i]; else { fail[nxt[now][i]]=nxt[fail[now]][i]; end[nxt[now][i]]+=end[nxt[fail[now]][i]]; q.push(nxt[now][i]); } } } int sum[N]; int query(char buf[]){ int len=strlen(buf+1); int now=root; for(int i=1;i<=len;i++){ now=nxt[now][buf[i]-'a']; sum[i]+=end[now]; } } }; char buf[N],t[N]; Trie t1,t2; int n; void reserve(char s[]){ int i=1,j=strlen(s+1); while(i<j){ swap(s[i],s[j]); ++i,--j; } } int main(){ t1.init(); t2.init(); scanf("%s%d",t+1,&n); for(int i=1;i<=n;i++){ scanf("%s",buf+1); t1.insert(buf); reserve(buf); t2.insert(buf); } t1.build(); t2.build(); t1.query(t); reserve(t); t2.query(t); int len=strlen(t+1); long long ans=0; for(int i=0;i<len;i++) ans+=(long long)t1.sum[i]*t2.sum[len-i]; cout<<ans<<' '; }