之前看过几次后缀自动机,然后因为人太蠢都没看懂。
最近重新填坑TAT。。。
BZOJ4032: [HEOI2015]最短不公共子串
建出后缀自动机和序列自动机,然后我们知道自动机上每一条路径都相当于一个子串(子序列),这样只要从根节点开始bfs一遍,找到A有而B没有的,那就是字典序最小的辣。
#include<cstring> #include<iostream> #include<algorithm> #include<cstdio> #include<queue> #define rep(i,l,r) for (int i=l;i<=r;i++) #define down(i,l,r) for (int i=l;i>=r;i--) #define clr(x,y) memset(x,y,sizeof(x)) #define maxn 4005 #define ll long long #define inf int(1e9) using namespace std; char sa[maxn],sb[maxn]; int f[maxn][maxn]; int n,m; int read(){ int x=0,f=1; char ch=getchar(); while (!isdigit(ch)) {if (ch=='-') f=-1; ch=getchar();} while (isdigit(ch)) {x=x*10+ch-'0'; ch=getchar();} return x*f; } struct data{ struct sam{int l,fa,ch[30];} sam[maxn]; int tot,root,last,head[30]; void init(){ clr(sam,0); tot=1; clr(head,0); } void extend(int c){ int p,np,q,nq; p=last,np=++tot; last=np; sam[np].l=sam[p].l+1; for (;p&&!sam[p].ch[c];p=sam[p].fa) sam[p].ch[c]=np; if (!p) sam[np].fa=1; else { q=sam[p].ch[c]; if (sam[p].l+1==sam[q].l) sam[np].fa=q; else { nq=++tot; sam[nq].l=sam[p].l+1; memcpy(sam[nq].ch,sam[q].ch,sizeof(sam[q].ch)); sam[nq].fa=sam[q].fa; sam[np].fa=sam[q].fa=nq; for (;sam[p].ch[c]==q;p=sam[p].fa) sam[p].ch[c]=nq; } } } void sambuild(int n,char s[]){ init(); tot=last=1; rep(i,1,n) extend(s[i]-'a'); } void quebuild(int n,char s[]){ int o,p,c; init(); rep(i,0,25) head[i]=1; rep(i,1,n){ o=++tot; c=s[i]-'a'; rep(j,0,25) for (p=head[j];p&&!sam[p].ch[c];p=sam[p].fa) sam[p].ch[c]=o; sam[o].fa=head[c]; head[c]=o; } } } A,B; struct node{int x,y;}; int solve(){ queue<node> q; clr(f,0); q.push((node){1,1}); f[1][1]=0; while (!q.empty()){ int ux=q.front().x,uy=q.front().y; q.pop(); rep(i,0,25){ if (!A.sam[ux].ch[i]) continue; if (!B.sam[uy].ch[i]) return f[ux][uy]+1; int vx=A.sam[ux].ch[i],vy=B.sam[uy].ch[i]; if (!f[vx][vy]) q.push((node){vx,vy}),f[vx][vy]=f[ux][uy]+1; } } return -1; } int main(){ scanf("%s",sa+1); scanf("%s",sb+1); n=strlen(sa+1); m=strlen(sb+1); A.sambuild(n,sa); B.sambuild(m,sb); printf("%d ",solve()); B.quebuild(m,sb); printf("%d ",solve()); A.quebuild(n,sa); B.sambuild(m,sb); printf("%d ",solve()); B.quebuild(m,sb); printf("%d ",solve()); return 0; }
Luv Letter 【弱省胡策】Round #0
这道和上一道差不多的,上面是求字典序最小,这道是求出现次数,于是我们可以对AB都建自动机。然后从根开始拓扑一遍就可以得到每个节点后面有多少个子串,然后再记忆化搜索一遍枚举所有子串(子序列)统计答案就好了。
#include<cstring> #include<iostream> #include<algorithm> #include<cstdio> #include<queue> #define rep(i,l,r) for (int i=l;i<=r;i++) #define down(i,l,r) for (int i=l;i>=r;i--) #define clr(x,y) memset(x,y,sizeof(x)) #define maxn 4005 #define ll long long #define inf int(1e9) #define mm 1000000007 using namespace std; char sa[maxn],sb[maxn]; ll f[maxn][maxn]; int vis[maxn][maxn]; int n,m; int read(){ int x=0,f=1; char ch=getchar(); while (!isdigit(ch)) {if (ch=='-') f=-1; ch=getchar();} while (isdigit(ch)) {x=x*10+ch-'0'; ch=getchar();} return x*f; } struct data{ struct sam{int l,fa,ch[30];ll sz;} sam[maxn]; int tot,root,last,head[30]; void init(){ clr(sam,0); tot=1; clr(head,0); } void extend(int c){ int p,np,q,nq; p=last,np=++tot; last=np; sam[np].l=sam[p].l+1; for (;p&&!sam[p].ch[c];p=sam[p].fa) sam[p].ch[c]=np; if (!p) sam[np].fa=1; else { q=sam[p].ch[c]; if (sam[p].l+1==sam[q].l) sam[np].fa=q; else { nq=++tot; sam[nq].l=sam[p].l+1; memcpy(sam[nq].ch,sam[q].ch,sizeof(sam[q].ch)); sam[nq].fa=sam[q].fa; sam[np].fa=sam[q].fa=nq; for (;sam[p].ch[c]==q;p=sam[p].fa) sam[p].ch[c]=nq; } } } void dfs(int u){ if (!u||sam[u].sz) return; sam[u].sz=1; rep(i,0,25) { int v=sam[u].ch[i]; if (!v) continue; if (!sam[v].sz) dfs(v); sam[u].sz=(sam[u].sz+sam[v].sz)%mm; } } void sambuild(int n,char s[]){ init(); tot=last=1; rep(i,1,n) extend(s[i]-'a'); dfs(1); } void quebuild(int n,char s[]){ int o,p,c; init(); rep(i,0,25) head[i]=1; rep(i,1,n){ o=++tot; c=s[i]-'a'; rep(j,0,25) for (p=head[j];p&&!sam[p].ch[c];p=sam[p].fa) sam[p].ch[c]=o; sam[o].fa=head[c]; head[c]=o; } dfs(1); } } A,B; struct node{int x,y;}; ll dfs(int x,int y){ if (vis[x][y]) return f[x][y]; vis[x][y]=1; if (!y) return f[x][y]=A.sam[x].sz; f[x][y]=0; rep(i,0,25){ int vx=A.sam[x].ch[i],vy=B.sam[y].ch[i]; if (!vx) continue; f[x][y]=(f[x][y]+dfs(vx,vy))%mm; } return f[x][y]=f[x][y]; } ll solve(){ clr(vis,0); return dfs(1,1); } int main(){ scanf("%s",sa+1); scanf("%s",sb+1); n=strlen(sa+1); m=strlen(sb+1); A.sambuild(n,sa); B.sambuild(m,sb); printf("%lld ",solve()); B.quebuild(m,sb); printf("%lld ",solve()); A.quebuild(n,sa); B.sambuild(m,sb); printf("%lld ",solve()); B.quebuild(m,sb); printf("%lld ",solve()); return 0; }
3998: [TJOI2015]弦论
字典序k大问题,对于多个相同子串算一个的那么按上面的方法扫一遍就行了,那多个相同子串算多个就要先算出每个节点的|right|,其实就是这个串的出现次数。(求right大小比较重要吧,蒟蒻理解了好久TAT。。。然后记住要倒序添加到fail上面。
#include<cstring> #include<cstdio> #include<iostream> #include<algorithm> #include<cmath> #define rep(i,l,r) for (int i=l;i<=r;i++) #define down(i,l,r) for (int i=l;i>=r;i--) #define clr(x,y) memset(x,y,sizeof(x)) #define maxn 1000500 using namespace std; int last,tot,n,T,K; char s[maxn]; int fa[maxn],ch[maxn][31],l[maxn],q[maxn],b[maxn],sum[maxn],val[maxn],vis[maxn]; int read(){ int x=0,f=1; char ch=getchar(); while (!isdigit(ch)) {if (ch=='-') f=-1; ch=getchar();} while (isdigit(ch)) {x=x*10+ch-'0'; ch=getchar();} return x*f; } void expand(int c){ int p,np,q,nq; p=last; last=np=++tot; l[np]=l[p]+1; for (;p&&!ch[p][c];p=fa[p]) ch[p][c]=np; if (!p) fa[np]=1; else { q=ch[p][c]; if (l[q]==l[p]+1) fa[np]=q; else { nq=++tot; l[nq]=l[p]+1; memcpy(ch[nq],ch[q],sizeof(ch[nq])); fa[nq]=fa[q]; fa[np]=fa[q]=nq; for (;p&&ch[p][c]==q;p=fa[p]) ch[p][c]=nq; } } val[np]=1; } void dfs(int u,int k){ if (k>=sum[u]) return; k-=val[u]; if (!k) return; rep(j,0,25) if (ch[u][j]){ int v=ch[u][j]; if (sum[v]>=k){ putchar(j+'a'); dfs(v,k); return ; } k-=sum[v]; } } int main(){ scanf("%s",s+1); n=strlen(s+1); tot=last=1; rep(i,1,n) expand(s[i]-'a'); T=read(); K=read(); rep(i,1,tot) b[l[i]]++; rep(i,1,n) b[i]+=b[i-1]; rep(i,1,tot) q[b[l[i]]--]=i; down(i,tot,1) { int t=q[i]; if (T==1) val[fa[t]]+=val[t]; else val[t]=1; } val[1]=0; down(i,tot,1){ int t=q[i]; sum[t]=val[t]; rep(j,0,25) sum[t]+=sum[ch[t][j]]; } if (sum[1]<K) {puts("-1"); return 0;} dfs(1,K); return 0; }
3238: [Ahoi2013]差异
逆序建出sam,它的fail树就是后缀树,然后dfs扫一遍就好了。
#include<cstring> #include<iostream> #include<algorithm> #include<cstdio> #include<queue> #define rep(i,l,r) for (int i=l;i<=r;i++) #define down(i,l,r) for (int i=l;i>=r;i--) #define clr(x,y) memset(x,y,sizeof(x)) #define maxn 1005000 #define ll long long #define inf int(1e9) #define mm 1000000007 using namespace std; char s[maxn]; ll ans; int n,m; int read(){ int x=0,f=1; char ch=getchar(); while (!isdigit(ch)) {if (ch=='-') f=-1; ch=getchar();} while (isdigit(ch)) {x=x*10+ch-'0'; ch=getchar();} return x*f; } struct data{ struct sam{int l,fa,ch[30];ll sz;} sam[maxn]; int tot,root,last; void init(){ clr(sam,0); tot=1; } void extend(int c){ int p,np,q,nq; p=last,np=++tot; last=np; sam[np].l=sam[p].l+1; for (;p&&!sam[p].ch[c];p=sam[p].fa) sam[p].ch[c]=np; if (!p) sam[np].fa=1; else { q=sam[p].ch[c]; if (sam[p].l+1==sam[q].l) sam[np].fa=q; else { nq=++tot; sam[nq].l=sam[p].l+1; memcpy(sam[nq].ch,sam[q].ch,sizeof(sam[q].ch)); sam[nq].fa=sam[q].fa; sam[np].fa=sam[q].fa=nq; for (;sam[p].ch[c]==q;p=sam[p].fa) sam[p].ch[c]=nq; } } sam[np].sz=1; } void sambuild(int n,char s[]){ tot=last=1; down(i,n,1) extend(s[i]-'a'); } } A,B; ll sum[maxn]; int l[maxn],head[maxn],tot; struct node{int obj,pre; }e[maxn]; void insert(int x,int y){ e[++tot].obj=y; e[tot].pre=head[x]; head[x]=tot; } void dfs(int u){ for (int j=head[u];j;j=e[j].pre){ int v=e[j].obj; dfs(v); ans-=2LL*l[u]*sum[u]*sum[v]; sum[u]+=sum[v]; } } int main(){ scanf("%s",s+1); n=strlen(s+1); ans=1LL*n*(n-1)*(n+1)/2; A.sambuild(n,s); rep(i,1,A.tot) l[i]=A.sam[i].l,sum[i]=A.sam[i].sz,insert(A.sam[i].fa,i); dfs(1); printf("%lld ",ans); return 0; }
Problem - 653F - Codeforces
题意是给你一个括号序列,求合法子串个数。
SAM+RMQ
建出sam,把 fail 树建出来之后就得到了一堆前缀和这堆前缀的lcp(其实就是得到了逆序的后缀树)
对于 sam 上的一个点 i ,它的贡献就是以 pos[i] 为右端点的合法子串个数减去以 pos[fa[i]] 为右断点的合法子串个数。
比如 ()() ,现在要算第 4 个点的贡献,它的贡献就是 2-1=1 。
然后对于这个贡献,维护一个栈之类的左括号就压进去,右括号就拿出最后一个在栈里的括号并弹出栈,然后写个 RMQ 维护一下。
f[i][j] 表示位置 i 往左数共有 2^j 个封闭括号区间(比如 (())() 有两个区间)的最左边的 ( 的位置,显然有 f[i][j]=f[f[i][j-1]-1][j-1]
#include<cstring> #include<iostream> #include<cstdio> #include<algorithm> #include<queue> #include<cmath> #define rep(i,l,r) for (int i=l;i<=r;i++) #define down(i,l,r) for (int i=l;i>=r;i--) #define clr(x,y) memset(x,y,sizeof(x)) #define maxn 1005000 #define inf 2000000000 #define mm 1024523 #define eps 1e-6 #define uint unsigned int #define ll long long using namespace std; int pos[maxn],go[maxn][2],fa[maxn],len[maxn]; int n,last,tot,s[maxn],f[maxn][22]; char ch[maxn]; int read(){ int x=0,f=1; char ch=getchar(); while (!isdigit(ch)){if (ch=='-') f=-1; ch=getchar();} while (isdigit(ch)){x=x*10+ch-'0'; ch=getchar();} return x*f; } void add(int c){ int p,np,q,nq; p=last; last=np=++tot; len[np]=pos[np]=len[p]+1; for (;p&&!go[p][c];p=fa[p]) go[p][c]=np; if (!p) fa[np]=1; else { q=go[p][c]; if (len[q]==len[p]+1) { fa[np]=q; return; } else { nq=++tot; len[nq]=len[p]+1; pos[nq]=pos[q]; memcpy(go[nq],go[q],sizeof(go[q])); fa[nq]=fa[q]; fa[q]=fa[np]=nq; for (;p&&go[p][c]==q;p=fa[p]) go[p][c]=nq; } } } void pre(){ vector<int> q; rep(i,1,n){ if (!s[i]) q.push_back(i); else { if (!q.empty()) { int pos=q.back(); q.pop_back(); f[i][0]=pos; for (int j=1;j<=20;j++) if (f[i][j-1]>0) f[i][j]=f[f[i][j-1]-1][j-1]; } } } } int ask(int pos,int len){ int ans=0; down(i,20,0) if (f[pos][i]>0&&pos-f[pos][i]+1<=len) { ans+=1<<i; len-=pos-f[pos][i]+1; pos=f[pos][i]-1; } return ans; } int main(){ // freopen("in.txt","r",stdin); n=read(); last=tot=1; scanf("%s",ch+1); rep(i,1,n) if (ch[i]=='(') s[i]=0; else s[i]=1; rep(i,1,n) add(s[i]); pre(); ll ans=0; rep(i,2,tot) ans+=1LL*(ask(pos[i],len[i])-ask(pos[fa[i]],len[fa[i]])); printf("%lld ",ans); return 0; }
广义后缀自动机
(其实就是在trie上建一棵后缀自动机辣。有多个串的时候记住每次都要让last=1,否则串与串会连在一起形成新的串。
如果当前要加入的点已经有了且l[q]=l[p]+1,那么np=q,否则就新增一个点。。如果当前加入的点并没有,那就跟以前维护SAM一样就好了。。(我们可以花大量常数开map偷懒)
void expand(int x,int y){ int p,q,np,nq; p=last; if ((q=go[p][x])) { if (l[q]==l[p]+1) last=q; else { nq=++tot; l[nq]=l[p]+1; go[nq]=go[q]; fa[nq]=fa[q]; fa[q]=nq; for (;p&&go[p][x]==q;p=fa[p]) go[p][x]=nq; last=nq; } } else { np=++tot; l[np]=l[p]+1; for (;p&&!go[p][x];p=fa[p]) go[p][x]=np; if (!p) fa[np]=1; else { q=go[p][x]; if (l[q]==l[p]+1) fa[np]=q; else { nq=++tot; l[nq]=l[p]+1; go[nq]=go[q]; fa[nq]=fa[q]; fa[q]=fa[np]=nq; for (;p&&go[p][x]==q;p=fa[p]) go[p][x]=nq; } } last=np; } }
BZOJ2780: [Spoj]8093 Sevenk Love Oimaster
(题面简直笑抽2333333
题意是给你n个串,m个询问串,每个询问串是n个串中多少个串的字串。。
首先我们把广义后缀自动机建出来,然后fail[i]->i,建立一棵树,可以得到SAM每个节点的dfs序。
对于询问我们在trie上跑得到末尾的那个点,那么它子树的信息就是答案了。
问题转化为一个区间上有多少个不同的数字。上个树状数组。
#include<cstring> #include<iostream> #include<algorithm> #include<cstdio> #include<map> #define rep(i,l,r) for (int i=l;i<=r;i++) #define down(i,l,r) for (int i=l;i>=r;i--) #define clr(x,y) memset(x,y,sizeof(x)) #define maxn 600500 #define ll long long #define inf int(1e9) #define mm 1000000007 #define low(x) x&(-x) using namespace std; char s[maxn]; int n,m; int ans[maxn],dfn[maxn][2],pos[maxn],t[maxn]; int read(){ int x=0,f=1; char ch=getchar(); while (!isdigit(ch)) {if (ch=='-') f=-1; ch=getchar();} while (isdigit(ch)) {x=x*10+ch-'0'; ch=getchar();} return x*f; } struct edge{ struct data{int obj,pre;}e[maxn]; int head[maxn],tot; void insert(int x,int y){ e[++tot].obj=y; e[tot].pre=head[x]; head[x]=tot; } }A,B; struct node{int x,y,id; }a[maxn]; struct SAM{ int tot,last,idx; int l[maxn],fa[maxn]; map<int,int> go[maxn]; void expand(int x,int y){ int p,q,np,nq; p=last; if ((q=go[p][x])) { if (l[q]==l[p]+1) last=q; else { nq=++tot; l[nq]=l[p]+1; go[nq]=go[q]; fa[nq]=fa[q]; fa[q]=nq; for (;p&&go[p][x]==q;p=fa[p]) go[p][x]=nq; last=nq; } } else { np=++tot; l[np]=l[p]+1; for (;p&&!go[p][x];p=fa[p]) go[p][x]=np; if (!p) fa[np]=1; else { q=go[p][x]; if (l[q]==l[p]+1) fa[np]=q; else { nq=++tot; l[nq]=l[p]+1; go[nq]=go[q]; fa[nq]=fa[q]; fa[q]=fa[np]=nq; for (;p&&go[p][x]==q;p=fa[p]) go[p][x]=nq; } } last=np; } B.insert(last,y); } void dfs(int u){ dfn[u][0]=++idx; pos[idx]=u; for (int j=A.head[u];j;j=A.e[j].pre){ int v=A.e[j].obj; if (!dfn[v][0]) dfs(v); } dfn[u][1]=idx; } void init(int len,int y){ last=1; rep(i,1,len) expand(s[i],y); } node find(int x){ scanf("%s",s+1); int len=strlen(s+1),now=1; rep(i,1,len) now=go[now][s[i]]; if (!now) return (node){2,1,x}; else return (node){dfn[now][0],dfn[now][1],x}; } }T; bool cmp(node a,node b){ return a.y<b.y; } int vis[maxn]; void add(int x,int y){ while (x<=T.tot) { t[x]+=y; x+=low(x); } } int ask(int x){ int ans=0; while (x){ ans+=t[x]; x-=low(x); } return ans; } int main(){ n=read(); m=read(); T.tot=1; T.idx=0; rep(i,1,n) { scanf("%s",s+1); T.init(strlen(s+1),i); } rep(i,1,T.tot) A.insert(T.fa[i],i); T.dfs(1); rep(i,1,m) a[i]=T.find(i); sort(a+1,a+1+m,cmp); int now=1; rep(i,1,T.tot){ for (int j=B.head[pos[i]];j;j=B.e[j].pre){ int v=B.e[j].obj; if (vis[v]) add(vis[v],-1); vis[v]=i; add(vis[v],1); } while (a[now].y==i) ans[a[now].id]=ask(a[now].y)-ask(a[now].x-1),now++; } rep(i,1,m) printf("%d ",ans[i]); return 0; }
BZOJ 3473: 字符串
题意:给定n个字符串,询问每个字符串有多少子串(不包括空串)是所有n个字符串中至少k个字符串的子串?
类似于上一题。。fail[i]->i,建树之后,把SAM每个节点的dfs序求出来。
按照右端点排序维护一个树状数组。
一个节点如果出现次数>=k,那么它的贡献是l[i]-l[fail[i]]
然后每个点的总贡献是它到根上的权值和。
#include<cstring> #include<iostream> #include<algorithm> #include<cstdio> #include<map> #define rep(i,l,r) for (int i=l;i<=r;i++) #define down(i,l,r) for (int i=l;i>=r;i--) #define clr(x,y) memset(x,y,sizeof(x)) #define maxn 200500 #define ll long long #define inf int(1e9) #define mm 1000000007 #define low(x) x&(-x) using namespace std; char s[maxn]; ll f[maxn]; int n,k; int a[maxn],dfn[maxn][2],pos[maxn],t[maxn]; int read(){ int x=0,f=1; char ch=getchar(); while (!isdigit(ch)) {if (ch=='-') f=-1; ch=getchar();} while (isdigit(ch)) {x=x*10+ch-'0'; ch=getchar();} return x*f; } struct edge{ struct data{int obj,pre;}e[maxn]; int head[maxn],tot; void insert(int x,int y){ e[++tot].obj=y; e[tot].pre=head[x]; head[x]=tot; } }A,B,C; struct SAM{ int tot,last,idx; int l[maxn],fa[maxn]; map<int,int> go[maxn]; void expand(int x,int y){ int p,q,np,nq; p=last; if ((q=go[p][x])) { if (l[q]==l[p]+1) last=q; else { nq=++tot; l[nq]=l[p]+1; go[nq]=go[q]; fa[nq]=fa[q]; fa[q]=nq; for (;p&&go[p][x]==q;p=fa[p]) go[p][x]=nq; last=nq; } } else { np=++tot; l[np]=l[p]+1; for (;p&&!go[p][x];p=fa[p]) go[p][x]=np; if (!p) fa[np]=1; else { q=go[p][x]; if (l[q]==l[p]+1) fa[np]=q; else { nq=++tot; l[nq]=l[p]+1; go[nq]=go[q]; fa[nq]=fa[q]; fa[q]=fa[np]=nq; for (;p&&go[p][x]==q;p=fa[p]) go[p][x]=nq; } } last=np; } B.insert(last,y); } void dfs(int u){ dfn[u][0]=++idx; pos[idx]=u; for (int j=A.head[u];j;j=A.e[j].pre){ int v=A.e[j].obj; if (!dfn[v][0]) dfs(v); } dfn[u][1]=idx; } void init(int len,int y){ last=1; rep(i,1,len) expand(s[i],y),C.insert(y,last); } }T; int vis[maxn]; void add(int x,int y){ while (x<=T.tot) { t[x]+=y; x+=low(x); } } int ask(int x){ int ans=0; while (x){ ans+=t[x]; x-=low(x); } return ans; } void dfs2(int u){ for (int j=A.head[u];j;j=A.e[j].pre){ int v=A.e[j].obj; f[v]+=f[u]; dfs2(v); } } bool cmp(int a,int b){ return dfn[a][1]<dfn[b][1]; } int main(){ // freopen("in.txt","r",stdin); n=read(); k=read(); T.tot=1; T.idx=0; rep(i,1,n) { scanf("%s",s+1); T.init(strlen(s+1),i); } rep(i,1,T.tot) A.insert(T.fa[i],i),a[i]=i; T.dfs(1); sort(a+1,a+1+T.tot,cmp); int now=1; rep(i,1,T.tot){ for (int j=B.head[pos[i]];j;j=B.e[j].pre){ int v=B.e[j].obj; if (vis[v]) add(vis[v],-1); vis[v]=i; add(vis[v],1); } while (dfn[a[now]][1]==i) { f[a[now]]=ask(dfn[a[now]][1])-ask(dfn[a[now]][0]-1)>=k?T.l[a[now]]-T.l[T.fa[a[now]]]:0; now++; } } dfs2(1); rep(i,1,n) { ll ans=0; for (int j=C.head[i];j;j=C.e[j].pre){ int v=C.e[j].obj; ans+=f[v]; } printf("%lld",ans); if (i!=n) printf(" "); } return 0; }
BZOJ3926: [Zjoi20150]诸神眷顾的幻想乡
叶子节点很少嘛。。把每个度为1的点dfs一遍,建立出广义SAM(记住走u->v的时候要把last改回去)
只要把叶子都dfs一遍。就可以解决正反的问题辣。
每个节点的贡献是l[i]-l[fail[i]]
#include<cstring> #include<iostream> #include<algorithm> #include<cstdio> #include<map> #include<cmath> #define rep(i,l,r) for (int i=l;i<=r;i++) #define down(i,l,r) for (int i=l;i>=r;i--) #define clr(x,y) memset(x,y,sizeof(x)) #define maxn 4005000 #define ll long long #define inf 1152921504606846976 #define mm 1000000007 #define low(x) x&(-x) using namespace std; struct data{int obj,pre; }e[maxn*2]; int n,c,tot,d[maxn],a[maxn],head[maxn]; ll read(){ ll x=0,f=1; char ch=getchar(); while (!isdigit(ch)) {if (ch=='-') f=-1; ch=getchar();} while (isdigit(ch)) {x=x*10+ch-'0'; ch=getchar();} return x*f; } struct SAM{ int l[maxn],fa[maxn]; map<int,int> go[maxn]; int last,tot; void expand(int x){ int p,q,np,nq; p=last; if ((q=go[p][x])){ if (l[q]==l[p]+1) last=q; else { nq=++tot; l[nq]=l[p]+1; go[nq]=go[q]; fa[nq]=fa[q]; fa[q]=nq; for (;p&&go[p][x]==q;p=fa[p]) go[p][x]=nq; last=nq; } } else { np=++tot; l[np]=l[p]+1; for (;p&&!go[p][x];p=fa[p]) go[p][x]=np; if (!p) fa[np]=1; else { q=go[p][x]; if (l[q]==l[p]+1) fa[np]=q; else { nq=++tot; l[nq]=l[p]+1; go[nq]=go[q]; fa[nq]=fa[q]; fa[q]=fa[np]=nq; for (;p&&go[p][x]==q;p=fa[p]) go[p][x]=nq; } } last=np; } } }A; void insert(int x,int y){ e[++tot].obj=y; e[tot].pre=head[x]; head[x]=tot; d[x]++; e[++tot].obj=x; e[tot].pre=head[y]; head[y]=tot; d[y]++; } void dfs(int u,int f){ if (f==0) A.last=1; A.expand(a[u]); int now=A.last; for (int j=head[u];j;j=e[j].pre){ int v=e[j].obj; if (v!=f) dfs(v,u); A.last=now; } } int main(){ n=read(); c=read(); rep(i,1,n) a[i]=read(); A.tot=1; rep(i,1,n-1) { int x=read(),y=read(); insert(x,y); } rep(i,1,n) if (d[i]==1) {dfs(i,0);} ll ans=0; rep(i,1,A.tot) ans+=1LL*(A.l[i]-A.l[A.fa[i]]); printf("%lld ",ans); return 0; }