// Template problem: suffix array on a tree.
// https://www.codechef.com/problems/DIFTRIP
// CodeChef DIFTRIP, solved with a suffix array built on a tree.
//
// Every node v spells the string of vertex degrees read from v up to the root
// (node 1). The answer is the number of distinct non-empty prefixes of these
// n strings: sort the strings with prefix doubling, then sum
// depth[sa[i]] - lcp(sa[i-1], sa[i]).
//
// On a tree the 2^k-th ancestor plays the role of "position i + 2^k" in the
// classic string algorithm. Kasai's inequality h[rk[i]] >= h[rk[i-1]] - 1 does
// not hold here, so each adjacent LCP is computed by binary lifting. The
// lifting compares the exact ranks saved at every doubling level, so no
// hashing (and therefore no hash collision) is involved.
#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;

const int N = 200005;  // array capacity; node ids are 1..n
const int LOG = 20;    // number of doubling levels kept (2^(LOG-1) > N)

int n;
vector<int> adj[N];
// depth[v]: length of v's root-ward string (depth[1] = 1). Node 0 is a
// virtual node "above the root" with depth 0.
int depth[N];
// up[k][v]: the 2^k-th ancestor of v, or 0 once the walk passes the root.
int up[LOG][N];
// rnk[k][v]: dense rank (>= 1) of the first 2^k characters of v's string.
// rnk[k][0] stays 0 and stands for "past the root". rnk[0][v] is the degree of v.
int rnk[LOG][N];
int sa[N], secondKey[N], bySecond[N], cnt[N];
int levels;  // rnk[0..levels] are filled, and 2^levels > n

// Roots the tree at node 1 and fills depth[] and up[][].
// An iterative BFS is used because a recursive DFS can overflow the stack on
// a path-shaped tree.
static void buildTree() {
    vector<int> order;
    order.reserve(n);
    order.push_back(1);
    depth[1] = 1;  // up[0][1] stays 0: the root's parent is the virtual node 0
    for (size_t head = 0; head < order.size(); ++head) {
        const int v = order[head];
        for (const int w : adj[v]) {
            if (w == up[0][v]) continue;  // skip the edge back to the parent
            up[0][w] = v;
            depth[w] = depth[v] + 1;
            order.push_back(w);
        }
    }
    // Level k only needs level k-1 to be complete for every node.
    for (int k = 1; k < LOG; ++k)
        for (int v = 1; v <= n; ++v)
            up[k][v] = up[k - 1][up[k - 1][v]];
}

// Prefix doubling over the root-ward strings.
// Round k sorts the nodes by the pair (rnk[k][v], rnk[k][up[k][v]]) using two
// stable counting sorts, leaves that order in sa[1..n], and stores the new
// dense ranks in rnk[k + 1]. All keys lie in [0, n].
static void buildSuffixArray() {
    levels = 0;
    for (int len = 1; len <= n; len <<= 1, ++levels) {
        const int k = levels;
        const int* cur = rnk[k];
        for (int v = 1; v <= n; ++v) secondKey[v] = cur[up[k][v]];

        // Pass 1: order node ids by the secondary key.
        fill(cnt, cnt + n + 1, 0);
        for (int v = 1; v <= n; ++v) ++cnt[secondKey[v]];
        for (int r = 1; r <= n; ++r) cnt[r] += cnt[r - 1];
        for (int v = n; v >= 1; --v) bySecond[cnt[secondKey[v]]--] = v;

        // Pass 2: stable sort by the primary key (this yields sa for this round).
        fill(cnt, cnt + n + 1, 0);
        for (int v = 1; v <= n; ++v) ++cnt[cur[v]];
        for (int r = 1; r <= n; ++r) cnt[r] += cnt[r - 1];
        for (int j = n; j >= 1; --j) sa[cnt[cur[bySecond[j]]]--] = bySecond[j];

        // Dense re-ranking: equal pairs share a rank, and ranks start at 1 so
        // that 0 stays reserved for "past the root".
        int* nxt = rnk[k + 1];
        int p = 1;
        nxt[sa[1]] = p;
        for (int j = 2; j <= n; ++j) {
            const int prev = sa[j - 1], curNode = sa[j];
            if (cur[prev] != cur[curNode] || secondKey[prev] != secondKey[curNode]) ++p;
            nxt[curNode] = p;
        }
    }
}

// Length of the longest common prefix of the root-ward strings of u and v.
// Requires u, v >= 1 and n >= 2 (so every real character, a degree, is >= 1).
// Greedy binary lifting over the exact per-level ranks. Equal ranks mean the
// two 2^k-blocks are identical, including where the strings end. When both
// strings are fully equal, both walks fall off the root into node 0, so the
// result is read back from the depths (depth[0] = 0).
static int lcp(int u, int v) {
    int pu = u, pv = v;
    for (int k = levels - 1; k >= 0; --k) {
        if (rnk[k][pu] == rnk[k][pv]) {
            pu = up[k][pu];
            pv = up[k][pv];
        }
    }
    return min(depth[u] - depth[pu], depth[v] - depth[pv]);
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n;
    for (int i = 1; i < n; ++i) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    // The characters of the strings are vertex degrees.
    for (int v = 1; v <= n; ++v) rnk[0][v] = static_cast<int>(adj[v].size());

    buildTree();
    buildSuffixArray();

    // Distinct prefixes = sum over sorted strings of (length - LCP with the
    // previous string). The first string has no predecessor, so its LCP is 0.
    long long ans = 0;
    for (int i = 1; i <= n; ++i)
        ans += depth[sa[i]] - (i == 1 ? 0 : lcp(sa[i - 1], sa[i]));
    cout << ans << '\n';
    return 0;
}