题意：给定一棵 n 个结点的树，每个结点上有一个 [0, c-1] 内的数字。对任意两点 u、v（可以相同），把从 u 到 v 的路径上的数字依次连成一个字符串，求本质不同的字符串个数。
n<=1e5,c<=10
叶子结点（度为 1 的结点）不超过 20 个
思路:ZJOI2015都4年了……时间真快
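考虑任意一个答案串，设它对应路径 u→v。从 u 出发沿远离 v 的方向一直走，必能走到某个叶子 l（若 u 本身是叶子则 l = u）。以 l 为根时，u→v 是 l→v 这条自上而下路径的一段。因此每个答案串都是某棵「以叶子为根的 Trie」中某条从根出发的路径的子串。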
又因为叶子结点不超过 20 个，可以从每个叶子各 DFS 一遍，把这些 Trie 全部插入同一个后缀自动机（广义 SAM，总插入次数不超过 20n）。最后答案为所有状态的贡献之和 Σ(len[i] − len[link[i]])。
// Version 1: every leaf-rooted traversal of the tree is fed into a single
// suffix automaton using the ordinary SAM extension. When the transition
// p --x--> already exists, the extension still allocates a state np, but then
// st[np] == st[f[np]], so such states contribute nothing to the final count.
#include <cstdio>
#include <cstring>
#include <vector>

typedef long long ll;

// Directed-edge / vertex capacity: 2 * (n - 1) edges for n <= 1e5 vertices.
const int M = 210000;
// State capacity: at most 20 leaves * 1e5 insertions, <= 2 states each, plus the root.
const int N = 4000010;

int head[M], vet[M], nxt[M], a[M], d[M], tot, n, c;

// Reads the next decimal integer (optional leading '-') from stdin.
// Returns 0 if EOF is reached before any digit, instead of looping forever.
int read()
{
    int v = 0, sign = 1;
    int ch = getchar();  // int, not char, so EOF can be told apart from data
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        v = v * 10 + (ch - '0');
        ch = getchar();
    }
    return v * sign;
}

// Appends the directed edge u -> v to the adjacency lists (head / nxt / vet).
// Declared void: the old `int` version fell off the end without returning (UB).
void add(int u, int v)
{
    nxt[++tot] = head[u];
    vet[tot] = v;
    head[u] = tot;
}

struct sam
{
    int cnt;        // number of allocated states; state 1 is the root, 0 means "none"
    int t[N][10];   // transitions on digits 0..9 (the statement bounds c <= 10)
    int st[N];      // length of the longest string of each state
    int f[N];       // suffix link

    sam() { cnt = 1; }

    // Ordinary SAM extension: appends digit x to the string ending in state p
    // and returns the new state np (st[np] == st[p] + 1).
    int add(int p, int x)
    {
        int np = ++cnt;
        st[np] = st[p] + 1;
        // `p &&` keeps the walk from writing into the null state 0.
        while (p && !t[p][x])
        {
            t[p][x] = np;
            p = f[p];
        }
        if (!p)
        {
            f[np] = 1;
            return np;
        }
        int q = t[p][x];
        if (st[p] + 1 == st[q])
        {
            f[np] = q;
            return np;
        }
        // Split q: the clone nq takes the strings of length <= st[p] + 1.
        int nq = ++cnt;
        st[nq] = st[p] + 1;
        memcpy(t[nq], t[q], sizeof t[q]);
        f[nq] = f[q];
        f[q] = f[np] = nq;
        while (p && t[p][x] == q)
        {
            t[p][x] = nq;
            p = f[p];
        }
        return np;
    }

    // Prints the number of distinct non-empty strings: the sum over states of
    // st[i] - st[link(i)] (the root, state 1, contributes 0).
    void solve()
    {
        ll ans = 0;
        for (int i = 2; i <= cnt; i++) ans += st[i] - st[f[i]];
        printf("%lld\n", ans);
    }
} sam;

// Inserts every root-to-vertex path of the tree rooted at u (entered from
// parent fa, with the parent path in SAM state p) into the automaton.
// Iterative: on a path-shaped tree the depth reaches n, which can overflow
// the call stack in a recursive version. Each vertex is inserted after its
// parent, and siblings may be visited in any order.
void dfs(int u, int fa, int p)
{
    struct Frame { int u, fa, p; };
    std::vector<Frame> stk;
    stk.push_back({u, fa, p});
    while (!stk.empty())
    {
        Frame cur = stk.back();
        stk.pop_back();
        int x = sam.add(cur.p, a[cur.u]);
        for (int e = head[cur.u]; e; e = nxt[e])
            if (vet[e] != cur.fa) stk.push_back({vet[e], cur.u, x});
    }
}

int main()
{
    n = read(), c = read();  // c only bounds the alphabet; digits index t[][10]
    for (int i = 1; i <= n; i++) a[i] = read();
    for (int i = 1; i < n; i++)
    {
        int x = read(), y = read();
        add(x, y);
        add(y, x);
        d[x]++;
        d[y]++;
    }
    // Start from every leaf. In a tree, d[i] == 0 only when n == 1; that
    // lone vertex must still be inserted, otherwise the answer would be 0.
    for (int i = 1; i <= n; i++)
        if (d[i] <= 1) dfs(i, 0, 1);
    sam.solve();
    return 0;
}
下面把建 SAM 的部分换成专门针对 Trie 的广义 SAM 模板，又写了一遍：插入前先检查转移是否已经存在，若已存在则直接复用（必要时拆分）已有状态，而不是新建状态。
// Version 2: same idea as version 1, but the SAM extension is the trie-aware
// ("generalized SAM") variant. If the transition p --x--> already exists, it
// reuses that state, or splits it, instead of allocating a new one.
#include <cstdio>
#include <cstring>
#include <vector>

typedef long long ll;

// Directed-edge / vertex capacity: 2 * (n - 1) edges for n <= 1e5 vertices.
const int M = 210000;
// State capacity: at most 20 leaves * 1e5 insertions, <= 2 states each, plus the root.
const int N = 4000010;

int head[M], vet[M], nxt[M], a[M], d[M], tot, n, c;

// Reads the next decimal integer (optional leading '-') from stdin.
// Returns 0 if EOF is reached before any digit, instead of looping forever.
int read()
{
    int v = 0, sign = 1;
    int ch = getchar();  // int, not char, so EOF can be told apart from data
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        v = v * 10 + (ch - '0');
        ch = getchar();
    }
    return v * sign;
}

// Appends the directed edge u -> v to the adjacency lists (head / nxt / vet).
// Declared void: the old `int` version fell off the end without returning (UB).
void add(int u, int v)
{
    nxt[++tot] = head[u];
    vet[tot] = v;
    head[u] = tot;
}

struct sam
{
    int cnt;        // number of allocated states; state 1 is the root, 0 means "none"
    int ch[N][10];  // transitions on digits 0..9 (the statement bounds c <= 10)
    int st[N];      // length of the longest string of each state
    int fa[N];      // suffix link

    sam() { cnt = 1; }

    // Splits state q, the x-successor of p with st[q] > st[p] + 1.
    // - Creates a clone with length st[p] + 1 and the same transitions and link as q.
    // - Makes the clone q's suffix link.
    // - Redirects every x-transition to q along p's suffix chain to the clone.
    // Returns the clone.
    int clone(int p, int x, int q)
    {
        int nq = ++cnt;
        st[nq] = st[p] + 1;
        memcpy(ch[nq], ch[q], sizeof ch[q]);
        fa[nq] = fa[q];
        fa[q] = nq;
        while (p && ch[p][x] == q)
        {
            ch[p][x] = nq;
            p = fa[p];
        }
        return nq;
    }

    // Appends digit x to the string ending in state p and returns a state
    // with st == st[p] + 1 that represents the extended string.
    int add(int p, int x)
    {
        // Transition already present: reuse it, splitting it if it is too long.
        if (int q = ch[p][x])
            return st[q] == st[p] + 1 ? q : clone(p, x, q);

        int np = ++cnt;
        st[np] = st[p] + 1;
        while (p && !ch[p][x])
        {
            ch[p][x] = np;
            p = fa[p];
        }
        if (!p)
            fa[np] = 1;
        else
        {
            int q = ch[p][x];
            fa[np] = st[q] == st[p] + 1 ? q : clone(p, x, q);
        }
        return np;
    }

    // Prints the number of distinct non-empty strings: the sum over states of
    // st[i] - st[link(i)] (the root, state 1, contributes 0).
    void solve()
    {
        ll ans = 0;
        for (int i = 2; i <= cnt; i++) ans += st[i] - st[fa[i]];
        printf("%lld\n", ans);
    }
} sam;

// Inserts every root-to-vertex path of the tree rooted at u (entered from
// parent fa, with the parent path in SAM state p) into the automaton.
// Iterative: on a path-shaped tree the depth reaches n, which can overflow
// the call stack in a recursive version. Each vertex is inserted after its
// parent, and siblings may be visited in any order.
void dfs(int u, int fa, int p)
{
    struct Frame { int u, fa, p; };
    std::vector<Frame> stk;
    stk.push_back({u, fa, p});
    while (!stk.empty())
    {
        Frame cur = stk.back();
        stk.pop_back();
        int x = sam.add(cur.p, a[cur.u]);
        for (int e = head[cur.u]; e; e = nxt[e])
            if (vet[e] != cur.fa) stk.push_back({vet[e], cur.u, x});
    }
}

int main()
{
    n = read(), c = read();  // c only bounds the alphabet; digits index ch[][10]
    for (int i = 1; i <= n; i++) a[i] = read();
    for (int i = 1; i < n; i++)
    {
        int x = read(), y = read();
        add(x, y);
        add(y, x);
        d[x]++;
        d[y]++;
    }
    // Start from every leaf. In a tree, d[i] == 0 only when n == 1; that
    // lone vertex must still be inserted, otherwise the answer would be 0.
    for (int i = 1; i <= n; i++)
        if (d[i] <= 1) dfs(i, 0, 1);
    sam.solve();
    return 0;
}