[HAOI2016] 找相同字符
Description
给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两个子串中有一个位置不同。 (n,m le 200000)
Solution
将两个字符串按序连接,中间用一个其它字符隔开,新串记为 (S) ,那么 (S[1 , n]) 为第一个字符串的对应部分, (S[n+2 ,n+m+1]) 为第二个字符串对应的部分。预处理出 (S) 的后缀数组和高度数组。
我们将这些后缀中起始位置在 ([1,n]) 内的称为黑后缀, ([n+2,n+m+1]) 内的成为白后缀。那么我们可以考虑对每个白后缀它与所有黑后缀匹配的答案,这个贡献就是它与所有黑后缀的 (LCP) 长度的和。
假设所有后缀的顺序是按后缀排序的,不难想到分为左边的黑后缀和右边的黑后缀两部分,分开处理。
对于所黑后缀在白后缀左边的答案,我们可以按照顺序扫描所有的白后缀,同时维护到当前位置为止,当前位置串与前面任意一个位置串的 (LCP) 长度,同时记录它们的和。不难发现这可以用一个单调栈来处理。具体地,我们维护一个单调递增的栈,在栈中需要记录每个元素的下标,配合一个描述黑白后缀的前缀和数组,这样可以快速更新答案。
每当我们扫描到一个白后缀,就把单调栈中的和加进答案。
同理对所有黑后缀在白后缀右边的答案也去这样处理一遍。得到的就是最终的结果。
实现层面上这个题还是挺水了(虽然还是挠了半天毛)
Code
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N = 500005;
#define reset(x) memset(x,0,sizeof x)
namespace SA
{
int n,m=256,sa[N],y[N],u[N],v[N],o[N],r[N],h[N],T;
char str[N];
void solve()
{
n=strlen(str+1);
for(int i=1; i<=n; i++)
u[str[i]]++;
for(int i=1; i<=m; i++)
u[i]+=u[i-1];
for(int i=n; i>=1; i--)
sa[u[str[i]]--]=i;
r[sa[1]]=1;
for(int i=2; i<=n; i++)
r[sa[i]]=r[sa[i-1]]+(str[sa[i]]!=str[sa[i-1]]);
for(int l=1; r[sa[n]]<n; l<<=1)
{
memset(u,0,sizeof u);
memset(v,0,sizeof v);
memcpy(o,r,sizeof r);
for(int i=1; i<=n; i++)
u[r[i]]++, v[r[i+l]]++;
for(int i=1; i<=n; i++)
u[i]+=u[i-1], v[i]+=v[i-1];
for(int i=n; i>=1; i--)
y[v[r[i+l]]--]=i;
for(int i=n; i>=1; i--)
sa[u[r[y[i]]]--]=y[i];
r[sa[1]]=1;
for(int i=2; i<=n; i++)
r[sa[i]]=r[sa[i-1]]+((o[sa[i]]!=o[sa[i-1]])||(o[sa[i]+l]!=o[sa[i-1]+l]));
}
{
int i,j,k=0;
for(int i=1; i<=n; h[r[i++]]=k)
for(k?k--:0,j=sa[r[i]-1]; str[i+k]==str[j+k]; k++);
}
}
}
namespace MS
{
int a[N],b[N],c[N],p,ans;
void init()
{
reset(a);
reset(b);
reset(c);
p=0;
ans=0;
}
void push(int x,int y)
{
while(x<=a[p]&&p)
{
ans-=(c[b[p]]-c[b[p-1]])*a[p];
--p;
}
a[++p]=x;
b[p]=y;
ans+=(c[b[p]]-c[b[p-1]])*x;
}
}
char a[N],b[N];
int n,m,t1,t2,t3,t4,ans;
signed main()
{
cin>>a+1>>b+1;
n=strlen(a+1);
m=strlen(b+1);
reset(SA::str);
for(int i=1; i<=n; i++)
SA::str[i]=a[i];
SA::str[n+1]='$';
for(int i=1; i<=m; i++)
SA::str[n+i+1]=b[i];
SA::solve();
for(int i=1; i<=n+m+1; i++)
MS::c[i+1]=(SA::sa[i]<=n?1:0);
for(int i=1; i<=n+m+1; i++)
MS::c[i]+=MS::c[i-1];
for(int i=1; i<=n+m+1; i++)
{
MS::push(SA::h[i],i);
if(SA::sa[i]<=n+1)
continue;
ans+=MS::ans;
}
MS::init();
for(int i=1; i<=n+m+1; i++)
MS::c[i+1]=(SA::sa[i]>n+1?1:0);
for(int i=1; i<=n+m+1; i++)
MS::c[i]+=MS::c[i-1];
for(int i=1; i<=n+m+1; i++)
{
MS::push(SA::h[i],i);
if(SA::sa[i]>n)
continue;
ans+=MS::ans;
}
cout<<ans<<endl;
}