题目大意:求$sum_{1leq i<j leq N} suf_{i}+suf_{j}-2cdot lcp(suf_{i},suf_{j})$
先是后缀数组打错了,又是把+=打成了=,我是zz
我的做法比较奇葩..
转化式子,原式=$sum_{i=1}^{n-1}(i+1)cdot i-sum_{1leq i<j leq N}2cdot lcp(suf_{i},suf_{j})$
这样计算后面的部分就行了
首先用$sa$预处理出$height$数组
对后缀进行排序后,对于某个一个后缀$suf_{i}$,如果另一个后缀$suf_{j}$和它的$lcp$长度是$x$,必须要保证$forall ;kin[i+1,j-1],h_{k}geq x$
用一个单调栈维护$height$,设$num_{tp}$表示栈中$lcp$长度为$L_{tp}$的后缀数量总和
用一个动态的数$sum$记录当前栈中的$h_{k}*num_{k}$总和
每遍历到一个排名$i$,因为 排名$<i$的后缀 和 排名$>i$的后缀 的$lcp$最长是$h_{i}$
故先删去栈中大于$h_{i}$的元素,并记录一共删掉了多少元素
在$i$之后,要去掉对于排名在i之后的后缀的无效长度,所有$h$大于$h_{i}$的部分都要修改成$h_{i}$,即$sum-=(L_{tp}-h_{i})cdot num_{tp}$
再把$suf_{i}$推入栈中,接着把删掉的元素数量加回来
最终答案就是每次统计完成后的$sum$总和
1 #include <bitset> 2 #include <cstdio> 3 #include <cstring> 4 #include <algorithm> 5 #define N1 505000 6 #define ll long long 7 #define inf 0x3f3f3f3f 8 #define rint register int 9 using namespace std; 10 11 12 int len; 13 int gch(char *str) 14 { 15 char c=getchar(); 16 while(c<'a'||c>'z'){c=getchar();} 17 while(c>='a'&&c<='z'){str[++len]=c;c=getchar();} 18 } 19 int gint() 20 { 21 int ret=0,fh=1;char c=getchar(); 22 while(c<'0'||c>'9'){if(c=='-')fh=-1;c=getchar();} 23 while(c>='0'&&c<='9'){ret=ret*10+c-'0';c=getchar();} 24 return ret*fh; 25 } 26 char str[N1]; 27 int rk[N1],tr[N1],sa[N1],hs[N1],h[N1]; 28 int check(int i,int j,int k){ 29 if(i+k>len||j+k>len) return 0; 30 return (rk[i]==rk[j]&&rk[i+k]==rk[j+k])?1:0;} 31 void get_sa() 32 { 33 rint i,cnt=0; 34 for(i=1;i<=len;i++) hs[str[i]]++; 35 for(i=1;i<=128;i++) if(hs[i]) tr[i]=++cnt; 36 for(i=1;i<=128;i++) hs[i]+=hs[i-1]; 37 for(i=1;i<=len;i++) rk[i]=tr[str[i]],sa[hs[str[i]]--]=i; 38 for(int k=1;cnt<len;k<<=1) 39 { 40 for(i=1;i<=cnt;i++) hs[i]=0; 41 for(i=1;i<=len;i++) hs[rk[i]]++; 42 for(i=1;i<=cnt;i++) hs[i]+=hs[i-1]; 43 for(i=len;i>=1;i--) if(sa[i]>k) tr[sa[i]-k]=hs[rk[sa[i]-k]]--; 44 for(i=1;i<=k;i++) tr[len-i+1]=hs[rk[len-i+1]]--; 45 for(i=1;i<=len;i++) sa[tr[i]]=i; 46 for(i=1,cnt=0;i<=len;i++) tr[sa[i]]=check(sa[i],sa[i-1],k)?cnt:++cnt; 47 for(i=1;i<=len;i++) rk[i]=tr[i]; 48 } 49 for(i=1;i<=len;i++){ 50 if(rk[i]==1) continue; 51 for(int j=max(1,h[rk[i-1]]-1);;j++) 52 if(str[i+j-1]==str[sa[rk[i]-1]+j-1]) h[rk[i]]=j; 53 else break; 54 } 55 } 56 int stk[N1],num[N1],L[N1],tp; 57 ll sum; 58 ll solve() 59 { 60 ll ans=0,tmp;tp=0; 61 for(int i=2;i<=len;i++) 62 { 63 tmp=0; 64 while(tp>0&&L[tp]>h[i]){ 65 tmp+=num[tp]; 66 sum-=1ll*(L[tp]-h[i])*num[tp]; 67 L[tp]=0,num[tp]=0,tp--; 68 } 69 if(h[i]>L[tp]) 70 tp++,L[tp]=h[i]; 71 num[tp]+=tmp+1; 72 sum+=h[i],ans+=sum; 73 } 74 return ans; 75 } 76 77 int main() 78 { 79 gch(str); 80 get_sa(); 81 ll ans=0; 82 for(int i=1;i<=len-1;i++) 83 ans+=1ll*(i+1)*i; 84 ans=ans/2*3; 85 printf("%lld ",ans-2ll*solve()); 86 return 0; 87 }