后缀数组+单调栈的应用
首先我们研究一下这个表达式,可以发现前半部分与串的情况并没有关系,而只是跟串的长度有关,所以我们先把前半部分算出来:
于是我们只需计算出即可
那么可以发现,对于排名分别为i,j的两个串,他们的lcp应当是:
但是这里的时间复杂度仍然很大
我们换一个角度来思考:如果设,那么我们认为height[k]产生了一个贡献
所以我们可以从每一个height[k]产生了多少贡献的角度来思考,这样就可以把时间复杂度降到O(n)
不难发现,一个k会对一个区间产生贡献的条件就是height[k]是所在区间的最小值
这就可以用单调栈维护了!!
但是要注意,为了防止重复计算,我们对单调栈的两端点的取等条件设成不一样的(即左侧算到第一个height小于等于height[k],右侧算到第一个height小于height[k]的位置)
这样找到每个点向左和向右能延伸的位置lx,rx这样他所占的区间个数就是(i-lx)*(rx-i)
这样去更新就可以了
#include <cstdio>
#include <cmath>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#include <queue>
#include <stack>
#define ll long long
using namespace std;
char s[500005];
int sa[500005];
int rank[500005];
int temprank[500005];
int height[500005];
int has[500005];
int v[500005];
int lx[500005],rx[500005];
int l;
bool be_same(int x,int y,int len)
{
return x+len>l||y+len>l||rank[x]!=rank[y]||rank[x+len]!=rank[y+len];
}
void get_sa()
{
int cnt=0;
for(int i=1;i<=l;i++)v[i]=s[i];
for(int i=1;i<=l;i++)has[v[i]]++;
for(int i=0;i<128;i++)if(has[i])temprank[i]=++cnt;
for(int i=1;i<128;i++)has[i]+=has[i-1];
for(int i=1;i<=l;i++)
{
rank[i]=temprank[v[i]];
sa[has[v[i]]--]=i;
}
for(int k=1;cnt!=l;k<<=1)
{
cnt=0;
for(int i=1;i<=l;i++)has[i]=0;
for(int i=1;i<=l;i++)has[rank[i]]++;
for(int i=1;i<=l;i++)has[i]+=has[i-1];
for(int i=l;i;i--)if(sa[i]>k)temprank[sa[i]-k]=has[rank[sa[i]-k]]--;
for(int i=1;i<=k;i++)temprank[l-i+1]=has[rank[l-i+1]]--;
for(int i=1;i<=l;i++)sa[temprank[i]]=i;
for(int i=1;i<=l;i++)temprank[sa[i]]=be_same(sa[i],sa[i-1],k)?++cnt:cnt;
for(int i=1;i<=l;i++)rank[i]=temprank[i];
}
for(int i=1;i<=l;i++)
{
if(rank[i]==1)continue;
int j=max(1,height[rank[i-1]]-1);
while(s[i+j-1]==s[sa[rank[i]-1]+j-1])height[rank[i]]=j++;
}
}
void init()
{
height[0]=height[l+1]=-0x3f3f3f3f;
ll ret=0;
for(int i=1;i<=l;i++)lx[i]=i-1,rx[i]=i+1;
for(int i=2;i<=l;i++)while(height[lx[i]]>height[i])lx[i]=lx[lx[i]];
for(int i=l;i>=2;i--)while(height[rx[i]]>=height[i])rx[i]=rx[rx[i]];
for(int i=2;i<=l;i++)ret+=2*height[i]*(ll)((ll)(rx[i]-i)*(ll)(i-lx[i]));
ll ans=1ll*(l-1)*l/2ll*(l+1);
printf("%lld
",ans-ret);
}
int main()
{
scanf("%s",s+1);
l=strlen(s+1);
get_sa();
init();
return 0;
}