Description
给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两
个子串中有一个位置不同。
Input
两行,两个字符串s1,s2,长度分别为n1,n2。1 <=n1, n2<= 200000,字符串中只有小写字母
Output
输出一个整数表示答案
Sample Input
aabb
bbaa
bbaa
Sample Output
10
Solution
很容易可以想到,我们要对字符串$A$建立$SAM$,然后用$B$串在$A$上跑
只要能匹配到这个节点,那么它顺着$fa$指针向上的那一串节点都可以产生贡献。
每个节点的贡献是$|R(s)|*max((min(Max,temp)-Min+1),0)$,$temp$为当前匹配长度
实际上当匹配到某个节点$s$的时候,$now$一定大于等于$fa(s)$
所以上面的节点直接用$Max$更新就可以。需要特判的只是当前的节点
只要能匹配到这个节点,那么它顺着$fa$指针向上的那一串节点都可以产生贡献。
每个节点的贡献是$|R(s)|*max((min(Max,temp)-Min+1),0)$,$temp$为当前匹配长度
实际上当匹配到某个节点$s$的时候,$now$一定大于等于$fa(s)$
所以上面的节点直接用$Max$更新就可以。需要特判的只是当前的节点
Code
1 #include<iostream> 2 #include<cstring> 3 #include<cstdio> 4 #define N (400000+1000) 5 using namespace std; 6 7 char s[N],t[N]; 8 long long Ans; 9 10 struct SAM 11 { 12 int son[N][28],fa[N],step[N],right[N],wt[N],od[N]; 13 int p,q,np,nq,last,cnt; 14 SAM(){last=++cnt;} 15 16 void Insert(int x) 17 { 18 p=last; np=last=++cnt; step[np]=step[p]+1; right[np]=1; 19 while (!son[p][x] && p) son[p][x]=np,p=fa[p]; 20 if (!p) fa[np]=1; 21 else 22 { 23 q=son[p][x]; 24 if (step[p]+1==step[q]) fa[np]=q; 25 else 26 { 27 nq=++cnt; step[nq]=step[p]+1; 28 memcpy(son[nq],son[q],sizeof(son[q])); 29 fa[nq]=fa[q]; fa[q]=fa[np]=nq; 30 while (son[p][x]==q) son[p][x]=nq,p=fa[p]; 31 } 32 } 33 } 34 void Init() 35 { 36 int len=strlen(s); 37 for (int i=1; i<=cnt; ++i) wt[step[i]]++; 38 for (int i=1; i<=len; ++i) wt[i]+=wt[i-1]; 39 for (int i=cnt; i>=1; --i) od[wt[step[i]]--]=i; 40 for (int i=cnt; i>=1; --i) right[fa[od[i]]]+=right[od[i]]; 41 } 42 void Solve(char s[]) 43 { 44 int now=1,len=strlen(s),temp=0; 45 for (int i=0; i<len; ++i) 46 { 47 while (now && !son[now][s[i]-'a']) 48 now=fa[now],temp=step[now]; 49 if (now==0){now=1; continue;} 50 now=son[now][s[i]-'a'],temp++; 51 Ans+=1ll*right[now]*(temp-step[fa[now]]); 52 int t=now; 53 while (fa[t]) t=fa[t],Ans+=1ll*right[t]*(step[t]-step[fa[t]]); 54 } 55 } 56 }SAM; 57 58 int main() 59 { 60 scanf("%s%s",s,t); 61 int len=strlen(s); 62 for (int i=0; i<len; ++i) 63 SAM.Insert(s[i]-'a'); 64 SAM.Init(); 65 SAM.Solve(t); 66 printf("%lld",Ans); 67 }