(\)
Description
- (nle 5 imes 10^5)
(\)
Solution
自已 YY 了一种跟黄学长不太一样的 (SA) 做法 ......
先考虑两个 (len) 求和最后的结果。
显然(suf(x)) 会作为 (j) 被算 (x-1) 次,作为 (i) 被算 (n-x) 次,总共 (n-1) 次。
所有后缀长度是等差数列,直接求和。
所求化简为
[(n-1) imes frac{n imes(n+1)}{2}-2 imes sum_{1le i<jle n} lcp(i,j)
]
问题在后面的东西怎么求。
(\)
考虑我们求两个串 (i,j (rank[i]<rank[j])) 的 (lcp) 怎么做。
答案是 (min_{k=i+1}^j{height[k]}) 。
注意到原来两个串构成的区间长度至少为 (2) ,而因为计算 (lcp) 时是左开右闭,所以最后询问的就是:
[(n-1) imes frac{n imes(n+1)}{2}-2 imes sum_{i=1}^nsum_{j=i}^nmin_{k=i}^j{height[k]}
]
也就是说我们需要求 (height) 数列所有区间 (min) 的和。
(\)
黄学长的写法是,考虑每一个位置作为最小值所控制的区间,向左向右分别单调栈搞一下。
自己 yy 出了一个扫一遍的做法,也是单调栈。
栈内每个元素维护两个东西,一个是当前块代表的最小值,另一个是当前块覆盖的区间。
考虑每次新来一个长度时计算以该位置结尾的答案,如果这个位置大于前面的所有数,答案显然直接继承上一个位置每一个区间的答案,然后再加上当前位置长度为 (1) 的区间的答案。
在考虑一般的情况。显然一个位置只能替换掉后面一段的区间,所以单调栈维护的时候记录当前位置覆盖的区间数即可。
快速查询栈内总和比较简单,注意弹栈的时候需要在总和里去掉原来的答案。
(\)
Code
#include<cmath>
#include<cstdio>
#include<cctype>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
#define N 500010
#define R register
using namespace std;
typedef long long ll;
char ss[N];
ll ans,h[N],s[N],sa[N],cnt[N],t1[N],t2[N];
inline void da(ll n,ll m){
ll *x=t1,*y=t2;
s[n++]=0;
for(R ll i=0;i<m;++i) cnt[i]=0;
for(R ll i=0;i<n;++i) ++cnt[x[i]=s[i]];
for(R ll i=1;i<m;++i) cnt[i]+=cnt[i-1];
for(R ll i=n-1;~i;--i) sa[--cnt[x[i]]]=i;
for(R ll k=1,p=0;p<n&&k<=n;k<<=1,m=p){
p=0;
for(R ll i=n-k;i<n;++i) y[p++]=i;
for(R ll i=0;i<n;++i) if(sa[i]>=k) y[p++]=sa[i]-k;
for(R ll i=0;i<m;++i) cnt[i]=0;
for(R ll i=0;i<n;++i) ++cnt[x[y[i]]];
for(R ll i=1;i<m;++i) cnt[i]+=cnt[i-1];
for(R ll i=n-1;~i;--i) sa[--cnt[x[y[i]]]]=y[i];
swap(x,y); p=1; x[sa[0]]=0;
for(R ll i=1;i<n;++i) x[sa[i]]=(y[sa[i]]==y[sa[i-1]]&&y[sa[i]+k]==y[sa[i-1]+k])?p-1:p++;
}
--n; h[0]=0;
for(R ll i=0;i<n;++i) sa[i]=sa[i+1];
for(R ll i=0;i<n;++i) x[sa[i]]=i;
for(R ll i=0,p=0;i<n;++i){
if(!x[i]) continue;
if(p) --p;
while(s[i+p]==s[sa[x[i]-1]+p]) ++p;
h[x[i]]=p;
}
}
struct s{ll x,cnt;}stk[N];
ll top=0;
int main(){
scanf("%s",ss);
ll n=strlen(ss);
for(R ll i=0;i<n;++i) s[i]=ss[i];
da(n,256);
ans=n*(n-1)*(n+1)/2;
for(R ll i=0,cnt=1,sum=0;i<n;++i,cnt=1){
while(top&&stk[top].x>=h[i]){
cnt+=stk[top].cnt;
sum-=stk[top].x*stk[top].cnt;
--top;
}
stk[++top].x=h[i];
stk[top].cnt=cnt;
sum+=stk[top].x*stk[top].cnt;
ans-=sum*2;
}
printf("%lld
",ans);
}