绝对是我写过最长的一份代码了.
这个快敲吐了.
通过这道题能 get 到一个套路:
两颗树同时统计信息的题可以考虑在个树上跑边分治,把点扔到另一颗树的虚树上,然后跑虚树DP.
具体地,这道题中我们发现 $LCP$ 长度是反串后缀树 $LCA$ 深度,$LCS$ 是正串后缀树 $LCA$ 深度.
我们建出正反两串后缀树后,将长度大于 K1/K2 的点的深度置为 0,然后跑一个边分+虚树即可.
code:
#include <cstdio> #include <algorithm> #include <cstring> #include <string> #include <vector> #include <map> #define N 200007 #define inf 0x3f3f3f3f #define ull unsigned long long // 代码已写完,人已阵亡. using namespace std; int bug; int K1,K2; ull ans,W; char S[N]; namespace IO { void setIO(string s) { string in=s+".in"; string out=s+".out"; freopen(in.c_str(),"r",stdin); // freopen(out.c_str(),"w",stdout); } }; struct SAM { #define M N<<1 int tot,last; struct Edge { int to,w; Edge(int to=0,int w=0):to(to),w(w){} }; vector<Edge>G[M]; int pre[M],ch[M][26],mx[M],str_sam[M],sam_str[M],depth[M]; void Initialize() { tot=last=1; } void extend(int c) { int np=++tot,p=last; mx[np]=mx[p]+1,last=np; for(;p&&!ch[p][c];p=pre[p]) ch[p][c]=np; if(!p) pre[np]=1; else { int q=ch[p][c]; if(mx[q]==mx[p]+1) pre[np]=q; else { int nq=++tot; mx[nq]=mx[p]+1; memcpy(ch[nq],ch[q],sizeof(ch[q])); pre[nq]=pre[q],pre[np]=pre[q]=nq; for(;p&&ch[p][c]==q;p=pre[p]) ch[p][c]=nq; } } } void Build_LCP() { int n=strlen(S+1),i,j,p=1; for(i=1;i<=n;++i) { p=ch[p][S[n-i+1]-'a']; sam_str[p]=n-i+1; str_sam[n-i+1]=p; } for(i=2;i<=tot;++i) { if(mx[i]>K1) depth[i]=0; else depth[i]=mx[i]; } for(i=2;i<=tot;++i) G[pre[i]].push_back(Edge(i,depth[i]-depth[pre[i]])); } void Build_LCS() { int n=strlen(S+1),i,j,p=1; for(i=1;i<=n;++i) { p=ch[p][S[i]-'a']; sam_str[p]=i; str_sam[i]=p; } for(i=2;i<=tot;++i) { if(mx[i]>K2) depth[i]=0; else depth[i]=mx[i]; } for(i=2;i<=tot;++i) G[pre[i]].push_back(Edge(i,depth[i]-depth[pre[i]])); } #undef M }lcp,lcs; namespace vir { vector<int>G[N<<2]; vector<int>clr; int t,sta,tot; int is1[N<<2],is2[N<<2]; int dfn[N<<2],dep[N<<2],size[N<<2],son[N<<2],top[N<<2],f[N<<2]; int S[N<<2],val[N<<2],re[N<<2]; ull size1[N<<2],size2[N<<2]; ull sum1[N<<2],sum2[N<<2]; bool cmp(int a,int b) { return dfn[a]<dfn[b]; } void get_dfn(int x,int fa) { dfn[x]=++t; size[x]=1; f[x]=fa; for(int i=0;i<lcp.G[x].size();++i) { int y=lcp.G[x][i].to; if(y==fa) continue; dep[y]=dep[x]+1; get_dfn(y,x); size[x]+=size[y]; if(size[y]>size[son[x]]) son[x]=y; } } void dfs2(int u,int tp) { top[u]=tp; if(son[u]) dfs2(son[u],tp); for(int i=0;i<lcp.G[u].size();++i) { int v=lcp.G[u][i].to; if(v==son[u]||v==f[u]) continue; dfs2(v,v); } } int LCA(int x,int y) { while(top[x]!=top[y]) { dep[top[x]]>dep[top[y]]?x=f[top[x]]:y=f[top[y]]; } return dep[x]<dep[y]?x:y; } void _new(int x,int v,int c) { ++tot; re[tot]=x; val[x]=v; if(c==1) is1[x]=1; else is2[x]=1; } void addvir(int x,int y) { G[x].push_back(y); } void Initialize() { t=0; get_dfn(1,0); dfs2(1,1); } void Insert(int x) { if(sta<=1) { S[++sta]=x; return; } int lca=LCA(S[sta],x); if(lca==S[sta]) S[++sta]=x; else { while(sta>1&&dep[S[sta-1]]>=dep[lca]) addvir(S[sta-1],S[sta]),--sta; if(S[sta]==lca) S[++sta]=x; else { addvir(lca,S[sta]); S[sta]=lca; S[++sta]=x; } } } void Build() { sta=0; sort(re+1,re+1+tot,cmp); if(re[1]!=1) S[++sta]=1; for(int i=1;i<=tot;++i) Insert(re[i]); while(sta>1) addvir(S[sta-1],S[sta]),--sta; } void DP(int x) { clr.push_back(x); for(int i=0;i<G[x].size();++i) { int y=G[x][i]; DP(y); size1[x]+=size1[y]; size2[x]+=size2[y]; sum1[x]+=sum1[y]; sum2[x]+=sum2[y]; } ull tmp=0; ull cntw=0; ull cur=0; for(int i=0;i<G[x].size();++i) { int y=G[x][i]; tmp+=(sum1[x]-sum1[y])*size2[y]; tmp+=(sum2[x]-sum2[y])*size1[y]; cntw+=(size1[x]-size1[y])*size2[y]; } cur+=tmp*lcp.depth[x]; cur-=cntw*W*lcp.depth[x]; if(is1[x]) { cur+=size2[x]*val[x]*lcp.depth[x]; cur+=sum2[x]*lcp.depth[x]; cur-=size2[x]*W*lcp.depth[x]; } if(is2[x]) { cur+=size1[x]*val[x]*lcp.depth[x]; cur+=sum1[x]*lcp.depth[x]; cur-=size1[x]*W*lcp.depth[x]; } ans+=cur/2; size1[x]+=is1[x]; size2[x]+=is2[x]; sum1[x]+=is1[x]*val[x]; sum2[x]+=is2[x]*val[x]; G[x].clear(); } void solve() { Build(); DP(1); for(int i=0;i<clr.size();++i) { int x=clr[i]; val[x]=sum1[x]=sum2[x]=size1[x]=size2[x]=is1[x]=is2[x]=0; } for(int i=1;i<=tot;++i) { re[i]=0; } tot=0; sta=0; clr.clear(); } }; // 虚树 int tot,edges=1; int totsize,rt1,rt2,mx,ed,lsc,rsc; int hd[N<<2],vis[N<<3],size[N<<2]; struct Edge { int to,w,nex; }e[N<<3]; struct Node { int u,dis,val; Node(int u=0,int dis=0,int val=0):u(u),dis(dis),val(val){} }L[N<<2],R[N<<2]; void add_div(int x,int y,int z) { e[++edges].nex=hd[x],hd[x]=edges,e[edges].to=y,e[edges].w=z; } void Build_Tree(int x,int fa) { int ff=0; for(int i=0;i<lcs.G[x].size();++i) { int y=lcs.G[x][i].to; if(y==fa) continue; if(!ff) { ff=x; add_div(ff,y,lcs.G[x][i].w); add_div(y,ff,lcs.G[x][i].w); } else { ++tot; add_div(ff,tot,0); add_div(tot,ff,0); add_div(tot,y,lcs.G[x][i].w); add_div(y,tot,lcs.G[x][i].w); ff=tot; } Build_Tree(y,x); } } void find_edge(int x,int fa) { size[x]=1; for(int i=hd[x];i;i=e[i].nex) { int y=e[i].to; if(y==fa||vis[i]) continue; find_edge(y,x); int now=max(size[y],totsize-size[y]); if(now<mx) { mx=now; ed=i; rt1=y; rt2=x; } size[x]+=size[y]; } } void get_node(int x,int fa,int dep,int ty) { if(ty==1) { if(lcs.sam_str[x]) L[++lsc]=Node(lcs.sam_str[x],lcs.depth[x]-dep,dep); } else { if(lcs.sam_str[x]) R[++rsc]=Node(lcs.sam_str[x],lcs.depth[x]-dep,dep); } for(int i=hd[x];i;i=e[i].nex) { int y=e[i].to; if(vis[i]||y==fa) continue; get_node(y,x,dep+e[i].w,ty); } } void Divide_And_conquer(int x) { if(totsize==1) return; mx=inf; rt1=rt2=ed=0; find_edge(x,0); vis[ed]=vis[ed^1]=1; lsc=rsc=0; get_node(rt1,0,0,1); get_node(rt2,0,0,2); W=(ull)e[ed].w; ull tmp=ans; for(int i=1;i<=lsc;++i) vir::_new(lcp.str_sam[L[i].u],L[i].dis,1); for(int i=1;i<=rsc;++i) vir::_new(lcp.str_sam[R[i].u],R[i].dis,2); vir::solve(); int tmprt1=rt1,tmprt2=rt2; int sizert1=size[rt1],sizert2=totsize-size[rt1]; totsize=sizert1; Divide_And_conquer(tmprt1); totsize=sizert2; Divide_And_conquer(tmprt2); } int main() { // IO::setIO("input"); int i,j,n; scanf("%s%d%d",S+1,&K1,&K2); n=strlen(S+1); lcp.Initialize(); lcs.Initialize(); for(i=1;i<=n;++i) { lcs.extend(S[i]-'a'); lcp.extend(S[n-i+1]-'a'); } lcs.Build_LCS(); lcp.Build_LCP(); tot=lcs.tot; Build_Tree(1,0); vir::Initialize(); totsize=tot; Divide_And_conquer(1); printf("%llu ",ans); return 0; }