题目链接:戳我
广义SAM
就是我们把两个串都建立到后缀自动机上面,然后记录一下每个np节点到底在哪个串里面出现了。(因为只有np节点是真正建立出来,有实际意义的,代表前缀的节点)
然后我们建立出来dfs树,用这个endpos类的大小,乘上它在第一个串中出现的次数*第二个串中出现的次数即可。
注意建立第二个串的时候last=1。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#define MAXN 800010
using namespace std;
int tot=1,last=1;
int siz[MAXN][3];
char s1[MAXN],s2[MAXN];
long long ans;
vector<int>g[MAXN];
struct Node{int len,ff,ch[26];}t[MAXN];
inline void extend(int c,int num)
{
int p=last,np=++tot;last=np;
t[np].len=t[p].len+1;
siz[np][num]++;
while(p&&!t[p].ch[c]) t[p].ch[c]=np,p=t[p].ff;
if(!p) t[np].ff=1;
else
{
int q=t[p].ch[c];
if(t[q].len==t[p].len+1) t[np].ff=q;
else
{
int nq=++tot;
t[nq]=t[q],t[nq].len=t[p].len+1;
t[np].ff=t[q].ff=nq;
while(p&&t[p].ch[c]==q) t[p].ch[c]=nq,p=t[p].ff;
}
}
}
inline void dfs(int x)
{
int cur_ans=t[x].len-t[t[x].ff].len;
for(int i=0;i<g[x].size();i++)
{
int v=g[x][i];
dfs(v);
siz[x][0]+=siz[v][0],siz[x][1]+=siz[v][1];
}
ans+=1ll*cur_ans*siz[x][0]*siz[x][1];
}
inline void solve()
{
for(int i=1;i<=tot;i++) g[t[i].ff].push_back(i);
dfs(1);
printf("%lld
",ans);
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("ce.in","r",stdin);
#endif
scanf("%s%s",s1,s2);
for(int i=0,len=strlen(s1);i<len;i++) extend(s1[i]-'a',0);
last=1;
for(int i=0,len=strlen(s2);i<len;i++) extend(s2[i]-'a',1);
solve();
return 0;
}