应该算是远古时期的一道题了吧,不过感觉挺经典的。
题意是给出三一个字符串s,a,b,求以a开头b结尾的本质不同的字符串数。
由于n不算大,用hash就可以搞,不过这道题是存在复杂度$O(nlogn)$的做法的。
由于要求本质不同,所以可以考虑使用后缀数组来不重复地枚举字符串。
首先用两个不同的其他字符将s,a,b拼起来求后缀数组,这样就可以知道任意两个后缀的lcp了。然后将s中所有b出现的末尾位置置1,求个后缀和suf。将s中所有后缀按名次从小到大存到一个vector里。对于s中的每个后缀,设其名次为x,a的长度为la,b的长度为lb,若$lcp(x,rnk[ia])=la$,则其对答案的贡献为$suf[sa[i]+max(lcp(x,vec[i-1]),la-1,lb-1)]$。其中la-1和lb-1是为了保证字符串长度比a和b都大,lcp(x,vec[i-1])是为了保证不重复枚举,相当于没有将s与a,b拼起来时的height[x]。特别地,当x=0时为0。
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 const int N=1e4+10; 5 char buf[N]; 6 int s[N],sa[N],buf1[N],buf2[N],c[N],n,rnk[N],ht[N],ST[N][20],Log[N],ia,ib,m,la,lb,suf[N]; 7 void Sort(int* x,int* y,int m) { 8 for(int i=0; i<m; ++i)c[i]=0; 9 for(int i=0; i<n; ++i)++c[x[i]]; 10 for(int i=1; i<m; ++i)c[i]+=c[i-1]; 11 for(int i=n-1; i>=0; --i)sa[--c[x[y[i]]]]=y[i]; 12 } 13 void da(int* s,int n,int m=1000) { 14 int *x=buf1,*y=buf2; 15 s[n]=x[n]=y[n]=-1; 16 for(int i=0; i<n; ++i)x[i]=s[i],y[i]=i; 17 Sort(x,y,m); 18 for(int k=1; k<n; k<<=1) { 19 int p=0; 20 for(int i=n-k; i<n; ++i)y[p++]=i; 21 for(int i=0; i<n; ++i)if(sa[i]>=k)y[p++]=sa[i]-k; 22 Sort(x,y,m),p=1,y[sa[0]]=0; 23 for(int i=1; i<n; ++i)y[sa[i]]=x[sa[i-1]]==x[sa[i]]&&x[sa[i-1]+k]==x[sa[i]+k]?p-1:p++; 24 if(p==n)break; 25 swap(x,y),m=p; 26 } 27 } 28 void getht() { 29 for(int i=0; i<n; ++i)rnk[sa[i]]=i; 30 ht[0]=0; 31 for(int i=0,k=0; i<n; ++i) { 32 if(k)--k; 33 if(!rnk[i])continue; 34 for(; s[i+k]==s[sa[rnk[i]-1]+k]; ++k); 35 ht[rnk[i]]=k; 36 } 37 } 38 void initST() { 39 for(int i=1; i<n; ++i)ST[i][0]=ht[i]; 40 for(int j=1; (1<<j)<=n; ++j) 41 for(int i=1; i+(1<<j)-1<n; ++i) 42 ST[i][j]=min(ST[i][j-1],ST[i+(1<<(j-1))][j-1]); 43 } 44 int lcp(int l,int r) { 45 if(l==r)return n-sa[l]; 46 if(l>r)swap(l,r); 47 l++; 48 int k=Log[r-l+1]; 49 return min(ST[l][k],ST[r-(1<<k)+1][k]); 50 } 51 vector<int> vec; 52 int main() { 53 Log[0]=-1; 54 for(int i=1; i<N; ++i)Log[i]=Log[i>>1]+1; 55 scanf("%s",buf),m=strlen(buf); 56 for(int i=0; i<m; ++i)s[n++]=buf[i]; 57 s[n++]='z'+1,ia=n; 58 scanf("%s",buf),m=strlen(buf),la=m; 59 for(int i=0; i<m; ++i)s[n++]=buf[i]; 60 s[n++]='z'+2,ib=n; 61 scanf("%s",buf),m=strlen(buf),lb=m; 62 for(int i=0; i<m; ++i)s[n++]=buf[i]; 63 s[n]=0; 64 da(s,n),getht(),initST(); 65 for(int i=0; i<ia-1; ++i)if(lcp(rnk[i],rnk[ib])==lb)suf[i+lb-1]=1; 66 for(int i=ia-3; i>=0; --i)suf[i]+=suf[i+1]; 67 for(int i=0; i<ia-1; ++i)vec.push_back(rnk[i]); 68 sort(vec.begin(),vec.end()); 69 int ans=0; 70 for(int i=0; i<vec.size(); ++i) { 71 int x=vec[i]; 72 if(lcp(x,rnk[ia])==la)ans+=suf[sa[x]+max(i?lcp(x,vec[i-1]):0,max(la-1,lb-1))]; 73 } 74 printf("%d ",ans); 75 return 0; 76 }