一个串建SAM,一个串在上面跑DP
需要注意,走到当前节点的时候,有可能走的是近路,并不能把当前节点表示的所有子串匹配,这个时候就要记录一下走的步数(类似caioj那题),那些被当前点表示的,长度不超过步数的子串才有资格更新答案。
这个东西我用g来维护
然后他去更新其他人就没有这个限制了,用h表示覆盖的次数,减去f表示直接走到的次数,然后乘上这个点代表的子串数和出现次数,就是其他人更新我的答案
g和这个东西加起来就是答案了
#include<cstdio> #include<iostream> #include<cstring> #include<cstdlib> #include<algorithm> #include<cmath> using namespace std; typedef long long LL; int a[210000],len; struct SAM { int w[30],dep,fail; }ch[410000];int last,cnt; void insert(int dep,int x) { int pre=last,now=++cnt; ch[now].dep=dep; last=now; while(pre!=0&&ch[pre].w[x]==0) ch[pre].w[x]=now, pre=ch[pre].fail; if(pre==0)ch[now].fail=1; else { int nxt=ch[pre].w[x]; if(ch[nxt].dep==ch[pre].dep+1)ch[now].fail=nxt; else { int nnxt=++cnt; ch[nnxt]=ch[nxt]; ch[nnxt].dep=ch[pre].dep+1; ch[nxt].fail=ch[now].fail=nnxt; while(pre!=0&&ch[pre].w[x]==nxt) ch[pre].w[x]=nnxt, pre=ch[pre].fail; } } } //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~init~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ int T,Right[410000];//根到当前节点组成的子串(当前节点管理的子串)出现次数,当前节点管理的子串数=ch[x].dep-ch[ch[x].fail].dep int Rsort[410000],sa[410000]; void GetRight() { memset(Rsort,0,sizeof(Rsort)); for(int i=1;i<=cnt;i++)Rsort[ch[i].dep]++; for(int i=1;i<=len;i++)Rsort[i]+=Rsort[i-1]; for(int i=cnt;i>=1;i--)sa[Rsort[ch[i].dep]--]=i; int now=1; memset(Right,0,sizeof(Right)); for(int i=1;i<=len;i++) now=ch[now].w[a[i]], Right[now]++; for(int i=cnt;i>=1;i--) { int u=sa[i],v=ch[u].fail; Right[v]+=Right[u]; } Right[1]=0; } //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //------------------------------------------------------SAM----------------------------------------------------------------- LL f[410000],g[410000],h[410000]; void solve() { int now=1,L=0; memset(f,0,sizeof(f)); memset(g,0,sizeof(g)); memset(h,0,sizeof(h)); for(int i=1;i<=len;i++) { int x=a[i]; while(now!=1&&ch[now].w[x]==0) now=ch[now].fail, L=ch[now].dep; if(ch[now].w[x]!=0) { L++; now=ch[now].w[x]; f[now]++,h[now]++; g[now]+=Right[now]*(L-ch[ch[now].fail].dep); } } for(int i=cnt;i>=1;i--) { int u=sa[i],v=ch[u].fail; f[v]+=f[u]; } LL ans=0; for(int i=2;i<=cnt;i++)ans+=g[i]+(f[i]-h[i])*Right[i]*(ch[i].dep-ch[ch[i].fail].dep); printf("%lld ",ans); } char ss[210000]; int main() { freopen("a.in","r",stdin); freopen("a.out","w",stdout); scanf("%s",ss+1);len=strlen(ss+1); last=cnt=1; ch[1].dep=0; for(int i=1;i<=len;i++) a[i]=ss[i]-'a'+1, insert(i,a[i]); GetRight(); scanf("%s",ss+1);len=strlen(ss+1); for(int i=1;i<=len;i++)a[i]=ss[i]-'a'+1; solve(); return 0; }