http://poj.org/problem?id=3415
题意:求两个字符串长度不小于k的公共子串数量
两个字符串用特殊字符连起来
后缀数组求出height数组
从大到小枚举,并查集合并
记录每一组 特殊字符前有多少个,特殊字符后有多少个,合并的贡献是 两者的乘积*(当前height-m+1)
#include<cstdio> #include<vector> #include<cstring> #include<iostream> #include<algorithm> using namespace std; #define N 100001 int n1,n,m; char s[N<<1]; int p=0,q=1; int v[N<<1]; int sa[2][N<<1],rk[2][N<<1],height[N<<1]; int fa[N<<1],siza[N<<1],sizb[N<<1]; vector<int>V[N<<1]; long long ans; void mul(int k,int *sa,int *rk,int *SA,int *RK) { for(int i=1;i<=n;++i) v[rk[sa[i]]]=i; for(int i=n;i;--i) if(sa[i]>k) SA[v[rk[sa[i]-k]]--]=sa[i]-k; for(int i=n-k+1;i<=n;++i) SA[v[rk[i]]--]=i; for(int i=1;i<=n;++i) RK[SA[i]]=RK[SA[i-1]]+(rk[SA[i]]!=rk[SA[i-1]]||rk[SA[i]+k]!=rk[SA[i-1]+k]); } void presa() { memset(v,0,sizeof(v)); for(int i=1;i<=n;++i) v[s[i]]++; for(int i=1;i<=130;++i) v[i]+=v[i-1]; for(int i=1;i<=n;++i) sa[p][v[s[i]]--]=i; for(int i=1;i<=n;++i) rk[p][sa[p][i]]=rk[p][sa[p][i-1]]+(s[sa[p][i-1]]!=s[sa[p][i]]); for(int k=1;k<n;k<<=1,swap(p,q)) mul(k,sa[p],rk[p],sa[q],rk[q]); } void get_height() { int j; for(int k=0,i=1;i<=n;++i) { j=sa[p][rk[p][i]-1]; while(s[j+k]==s[i+k]) k++; height[rk[p][i]]=k; if(k) k--; } } int find(int i) { return fa[i]==i ? i : fa[i]=find(fa[i]); } void unionn(int x,int y,int i) { x=find(x); y=find(y); ans+=1LL*siza[x]*sizb[y]*(i-m+1); ans+=1LL*sizb[x]*siza[y]*(i-m+1); siza[y]+=siza[x]; sizb[y]+=sizb[x]; fa[x]=y; } void solve() { for(int i=1;i<=n;++i) fa[i]=i; for(int i=1;i<=n1;++i) siza[i]=1,sizb[i]=0; for(int i=n1+2;i<=n;++i) sizb[i]=1,siza[i]=0; int mx=0; for(int i=2;i<=n;++i) V[height[i]].push_back(i),mx=max(mx,height[i]); int s,w; ans=0; for(int i=n;i>=m;--i) { s=V[i].size(); for(int j=0;j<s;++j) { w=V[i][j]; if(find(sa[p][w-1])!=find(sa[p][w])) unionn(sa[p][w-1],sa[p][w],i); if(w<n && height[w+1]>=i && find(sa[p][w+1])!=find(sa[p][w])) unionn(sa[p][w+1],sa[p][w],i); } } cout<<ans<<' '; for(int i=0;i<=mx;++i) V[i].clear(); } int main() { while(1) { scanf("%d",&m); if(!m) return 0; scanf("%s",s+1); n1=n=strlen(s+1); s[n+1]=char('a'+26); scanf("%s",s+n+2); n+=strlen(s+n+2)+1; presa(); get_height(); solve(); } }