Submit: 378 Solved: 209
Description
给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两
个子串中有一个位置不同。
Input
两行,两个字符串s1,s2,长度分别为n1,n2。1 <=n1, n2<= 200000,字符串中只有小写字母
Output
输出一个整数表示答案
Sample Input
aabb
bbaa
bbaa
Sample Output
10
HINT
Source
字符串 广义后缀自动机
把两个字符串建到同一个SAM上,分别记录size。如果某个结点同时有两个串的size,那么可以更新答案
注意第23行,如果一个节点已经有了,就不新建结点了。虽然不知道为什么,但是这步似乎是必要的,不加就WA
1 /*by SilverN*/ 2 #include<algorithm> 3 #include<iostream> 4 #include<cstring> 5 #include<cstdio> 6 #include<cmath> 7 #include<vector> 8 #define LL long long 9 using namespace std; 10 const int mxn=800010; 11 int read(){ 12 int x=0,f=1;char ch=getchar(); 13 while(ch<'0' || ch>'9'){if(ch=='-')f=-1;ch=getchar();} 14 while(ch>='0' && ch<='9'){x=x*10+ch-'0';ch=getchar();} 15 return x*f; 16 } 17 struct SAM{ 18 int t[mxn][26],fa[mxn],l[mxn]; 19 int sz[mxn][2],w[mxn],rk[mxn]; 20 int S,cnt; 21 void init(){S=cnt=1;return;} 22 int add(int c,int p,int bl){ 23 if(t[p][c] && l[t[p][c]]==l[p]+1)return t[p][c]; 24 int np=++cnt; 25 l[np]=l[p]+1; 26 // sz[np][bl]++; 27 for(;p && !t[p][c];p=fa[p])t[p][c]=np; 28 if(!p)fa[np]=S; 29 else{ 30 int q=t[p][c]; 31 if(l[q]==l[p]+1){fa[np]=q;} 32 else{ 33 int nq=++cnt;l[nq]=l[p]+1; 34 memcpy(t[nq],t[q],sizeof t[q]); 35 fa[nq]=fa[q]; 36 fa[q]=fa[np]=nq; 37 for(; p&& t[p][c]==q;p=fa[p])t[p][c]=nq; 38 } 39 } 40 // printf("np:%d %d ",np,fa[np]); 41 return np; 42 } 43 void solve(int len){ 44 int i,j; 45 for(i=1;i<=cnt;i++)w[l[i]]++; 46 for(i=1;i<=len;i++)w[i]+=w[i-1]; 47 for(i=1;i<=cnt;i++)rk[w[l[i]]--]=i; 48 LL ans=0; 49 for(int i=cnt;i;i--){ 50 int t=rk[i]; 51 // printf("t:%d %d ",t,fa[t]); 52 sz[fa[t]][0]+=sz[t][0]; 53 sz[fa[t]][1]+=sz[t][1]; 54 } 55 for(i=1;i<=cnt;i++){ 56 ans+=(LL)(l[i]-l[fa[i]])*sz[i][0]*sz[i][1]; 57 } 58 printf("%lld ",ans); 59 return; 60 } 61 }sa; 62 char s1[mxn>>1],s2[mxn>>1]; 63 int main(){ 64 int i,j; 65 sa.init(); 66 scanf("%s%s",s1+1,s2+1); 67 int n1=strlen(s1+1); 68 int n2=strlen(s2+1); 69 int id=sa.S; 70 for(i=1;i<=n1;i++)id=sa.add(s1[i]-'a',id,0),sa.sz[id][0]++; 71 id=sa.S; 72 for(i=1;i<=n2;i++)id=sa.add(s2[i]-'a',id,1),sa.sz[id][1]++; 73 sa.solve(max(n1,n2)); 74 return 0; 75 }