题目:http://poj.org/problem?id=3415
先用后缀数组处理出 ht[i];
用单调栈维护当前位置 ht[i] 对之前的 ht[j] 取 min 的结果,也就是当前的后缀与之前后缀的LCP,其中长度 >= K 的加到答案;
因为单调栈中是一段一段阶梯状的,只存了一段端点的位置,所以再记录一个 cnt 表示这一段的长度,算贡献时乘上 cnt;
因为是两个串之间,所以先统计 B 在 A 排名前的答案,再重复一遍统计 A 在 B 排名前的答案;
但是 ht[i] 是 sa[i] 和 sa[i-1] 的LCP,所以 ht[i] 是否计入贡献应该考虑 i-1 位置...突然变得很麻烦,不太会弄了...
于是参考了一下TJ(囧),原来就是判断一下 i-1 是否要被统计,如果要统计 B 而 i-1 是 B 中的,就把 ht[i] 也累加到 sum 中;
还有一个很好的操作是如果 ht[i] < K,那么取 min 显然都会取成 < K 的,没贡献了,所以直接 sum=0 , top = 0,就省去了 max(0,ht[i]-K+1) 的分类麻烦。
代码如下:
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> using namespace std; typedef long long ll; int const xn=1e5+5,xxn=(xn<<1);//xxn int n,m,tax[xxn],sa[xxn],rk[xxn],tp[xxn],ht[xxn],sta[xxn],top,cnt[xxn]; char a[xn],b[xn],s[xxn]; void Rsort() { for(int i=1;i<=m;i++)tax[i]=0; for(int i=1;i<=n;i++)tax[rk[tp[i]]]++; for(int i=1;i<=m;i++)tax[i]+=tax[i-1]; for(int i=n;i;i--)sa[tax[rk[tp[i]]]--]=tp[i]; } void work() { for(int i=1;i<=n;i++)rk[i]=s[i],tp[i]=i; Rsort(); for(int k=1;k<=n;k<<=1) { int num=0; for(int i=n-k+1;i<=n;i++)tp[++num]=i; for(int i=1;i<=n;i++) if(sa[i]>k)tp[++num]=sa[i]-k; Rsort(); memcpy(tp,rk,sizeof rk);//swap(rk,tp); rk[sa[1]]=1; num=1; for(int i=2;i<=n;i++) rk[sa[i]]=(tp[sa[i]]==tp[sa[i-1]]&&tp[sa[i]+k]==tp[sa[i-1]+k])?num:++num; if(num==n)break; m=num; } } void get() { int k=0; ht[1]=0; for(int i=1;i<=n;i++) { if(rk[i]==1)continue; if(k)k--; int j=sa[rk[i]-1]; while(i+k<=n&&j+k<=n&&s[i+k]==s[j+k])k++; ht[rk[i]]=k; } } int main() { int K; while(1) { scanf("%d",&K); if(!K)return 0; scanf("%s",a+1); int l1=strlen(a+1); scanf("%s",b+1); int l2=strlen(b+1); n=l1+l2+1; for(int i=1;i<=l1;i++)s[i]=a[i]; s[l1+1]='z'+1; for(int i=1;i<=l2;i++)s[l1+1+i]=b[i]; m=125; work(); get(); ll ans=0; ll sum=0; top=0; for(int i=1,y;i<=n;i++) { cnt[i]=0; if(ht[i]<K){top=0; sum=0; continue;}//min<K while(ht[i]<ht[y=sta[top]]&&top) { sum-=(ll)cnt[y]*(ht[y]-K+1); sum+=(ll)cnt[y]*(ht[i]-K+1);// top--; cnt[i]+=cnt[y]; } sta[++top]=i; if(sa[i-1]>l1+1)sum+=ht[i]-K+1,cnt[i]++;//cal(i-1):ht[i] if(sa[i]<=l1)ans+=sum; } sum=0; top=0; for(int i=1,y;i<=n;i++) { cnt[i]=0; if(ht[i]<K){top=0; sum=0; continue;}//min<K while(ht[i]<ht[y=sta[top]]&&top) { sum-=(ll)cnt[y]*(ht[y]-K+1); sum+=(ll)cnt[y]*(ht[i]-K+1); top--; cnt[i]+=cnt[y]; } sta[++top]=i; if(sa[i-1]<=l1)sum+=ht[i]-K+1,cnt[i]++; if(sa[i]>l1+1)ans+=sum; } printf("%lld ",ans); } return 0; }