Common Substrings
题意
给出两个字符串,求他俩长度>=k的公共子串的数量。
思路
(n^2) 的思路比较容易想到。
我们把两个字符串用一个没有出现过的字符隔开拼接起来,做后缀数组。
那么公共子串的数量,就是A串的后缀和B串的后缀之间的所有最长公共前缀和。
统计时,遍历(height)数组。
对于第i个后缀,遍历(j(j<i)),计算((i,j))的最长公共前缀
如果后缀(i)是A串中的,那么(j)就只统计B串,如果后缀(i)是B串,(j)只统计A串。
优化:
我们知道(lcp(i,j)=min(height[i+1]...height[j]))
根据这个可以知道对于后缀(i)统计的答案:
(lcp(i,1) ,lcp(i,2) , lcp(i,3), ... , lcp(i,i-1))
是非递减的。
知道这一点,先统计后缀(i)为B串中的后缀的答案。
维护一个(sum),表示后缀(i)与其前面A串的最长公共前缀和。(后缀(i)
不一定是B串中的后缀)
假如现在已经前6个后缀更新完了,贡献如下:
1 2 3 4 5
而(height[7]==2),那么对于后缀7来说,贡献就应该为:
1 2 2 2 2 2
此时我们维护一个单调递增栈和一个数组(num),单调栈中存放
上述贡献值,(num)数组存放栈内的贡献值出现的次数。
每次更新后缀的时候,遍历栈中贡献值>=(height[i])的,
(sum) 减去 贡献值变小的部分* 贡献值数量,将(height[i])入栈
如果此时后缀(i)为B串中的后缀,(ans+=sum)
然后统计为后缀(i)为A串的。
PS:
为什么要加一个字符隔开?
两个串分别为:aaaaa , aaaa;
如果不隔开,A串中的后缀5,和B串中的后缀1的公共前缀长度就是4,
但应该是1,所以要隔开。
代码
/*Gts2m ranks first in the world*/
#define pb push_back
#define stop system("pause")
#include<stdio.h>
#include<string.h>
#include<iostream>
#include<algorithm>
//#include<bits/stdc++.h>
using namespace std;
const int N=2e5+10;
typedef long long ll;
typedef unsigned long long ull;
char s[N],t[N];
int sa[N],rk[N],ht[N],oldrk[N],pos[N],cnt[N];
int n,m,lens,lent;
bool cmp(int a,int b,int k)
{
return oldrk[a]==oldrk[b]&&oldrk[a+k]==oldrk[b+k];
}
void getsa()
{
m=122;
memset(cnt,0,sizeof(cnt));
for(int i=1; i<=n; i++) ++cnt[rk[i]=s[i]];
for(int i=1; i<=m; i++) cnt[i]+=cnt[i-1];
for(int i=n; i; i--) sa[cnt[rk[i]]--]=i;
for(int k=1; k<=n; k<<=1)
{
int num=0;
for(int i=n-k+1; i<=n; i++) pos[++num]=i;
for(int i=1; i<=n; i++) if(sa[i]>k) pos[++num]=sa[i]-k;
memset(cnt,0,sizeof(cnt));
for(int i=1; i<=n; i++) ++cnt[rk[i]];
for(int i=1; i<=m; i++) cnt[i]+=cnt[i-1];
for(int i=n; i; i--) sa[cnt[rk[pos[i]]]--]=pos[i];
num=0;
memcpy(oldrk,rk,sizeof(rk));
for(int i=1; i<=n; i++) rk[sa[i]]=cmp(sa[i],sa[i-1],k)?num:++num;
if(num==n) break;
m=num;
}
for(int i=1; i<=n; i++) rk[sa[i]]=i;
int k=0;
for(int i=1; i<=n; i++)
{
if(k) --k;
while(s[i+k]==s[sa[rk[i]-1]+k]) ++k;
ht[rk[i]]=k;
}
}
int sta[N],num[N];
int main()
{
int k;
while(~scanf("%d",&k)&&k)
{
scanf("%s%s",s+1,t+1);
lens=strlen(s+1),lent=strlen(t+1);
n=lens;
s[++n]='$';
for(int i=1; i<=lent; i++) s[++n]=t[i];
s[++n]='%';
getsa();
ll ans=0,sum=0;
int top=0;
for(int i=1; i<=n; i++)
{
if(ht[i]<k)
sum=0,top=0;
else
{
int cnt=0;//cnt 表示应该等于ht[i]的贡献值的个数
if(sa[i-1]<=lens+1)//后缀sa[i-1]是A串时
{
cnt=1;
sum+=ht[i]-k+1;
}
while(top&&ht[i]<=ht[sta[top]])
{
cnt+=num[top];
sum-=1LL*num[top]*(ht[sta[top]]-ht[i]);//更新sum值
top--;
}
sta[++top]=i,num[top]=cnt;
if(sa[i]>lens+1)//当后缀sa[i]为B串时,更新答案
ans+=sum;
}
}
//统计后缀i为A串时的值
for(int i=1; i<=n; i++)
{
if(ht[i]<k)
sum=0,top=0;
else
{
int cnt=0;
if(sa[i-1]>lens+1)
{
cnt=1;
sum+=ht[i]-k+1;
}
while(top&&ht[i]<=ht[sta[top]])
{
cnt+=num[top];
sum-=1LL*num[top]*(ht[sta[top]]-ht[i]);
top--;
}
sta[++top]=i,num[top]=cnt;
if(sa[i]<=lens+1)
ans+=sum;
}
}
printf("%lld
",ans);
}
return 0;
}