Loj #2479. 「九省联考 2018」制胡窜
题目描述
对于一个字符串 (S),我们定义 (|S|) 表示 (S) 的长度。
接着,我们定义 (S_i) 表示 (S) 中第 (i) 个字符,(S_{L,R}) 表示由 (S) 中从左往右数,第 (L) 个字符到第 (R) 个字符依次连接形成的字符串。特别的,如果 (L > R) ,或者 (L < [1, |S|]), 或者 (R < [1, |S|]) 我们可以认为 (S_{L,R}) 为空串。
给定一个长度为 (n) 的仅由数字构成的字符串 (S),现在有 (q) 次询问,第 (k) 次询问会给出 (S) 的一个字符串 (S_{l,r}) ,请你求出有多少对 ((i, j)),满足 (1 le i < j le n),(i + 1 lt j),且 (S_{l,r}) 出现在 (S_{1,i}) 中或 (S_{i+1, j−1}) 中或 (S_{j,n}) 中。
输入格式
输入的第一行包含两个整数 (n, q)。
第二行包含一个长度为 (n) 的仅由数字构成的字符串 (S)。
接下来 (q) 行,每行两个正整数 (l) 和 (r),表示此次询问的子串是 (S_{l,r})。
输出格式
对于每个询问,输出一个整数表示合法的数对个数。
数据范围与提示
对于所有测试数据,(1 le n le 10^5),(1 le q le 3 · 10^5),(1 le l le r le n)。
(\)
感觉这道题细节贼烦人,正式考试的话估计可以刚一整场。
首先建后缀自动机,然后在使用线段树合并维护(endpos)集合。
询问的时候就先在(fail)树上倍增找到给定字符串出现的节点。然后我们将合法的((i,j))二元组分为以下三种情况:
- (S_{1,i})中出现
- (S_{1,i})中未出现,(S_{j,n})中出现
- (S_{1,i},S_{j,n})中为出现,(S_{i+1,j-1})中出现。
前两种情况很好算,找到位置最靠前以及最靠后的(endpos)就行了。
下面来考虑第三种情况。假设最靠前的(endpos)是(L),最靠后的是(R),字符串长度为(len)。显然(i<L,j>R-len+1)。
我们先考虑一种暴力做法:枚举(jin[R-len+2,n]),然后算对于每个(j)有多少个可行的(i)。设(<j)的最大的(endpos)为(mx),显然可行的(i)只与(mx)有关,为(min{L,mx-len})。
理解了这个暴力做法过后正解就差不多知道了。对于线段树上每个节点,我们令每个位置的权值为其左边第一个(endpos)(如果没有则为(0)),(sum)为这些位置的权值和,(rmax)为最右边的(endpos),(lempty)为左边有多少个位置没有(endpos)。注意上述的信息只考虑了线段树所表示的区间,区间外的(endpos)不对其产生任何影响。正因为如此,在询问的时候先遍历左儿子,动态更新最右边的(endpos),再遍历右儿子计算答案。
道理很简单,就是要注意的边界情况有点多。。。
代码:
#include<bits/stdc++.h>
#define ll long long
#define N 200005
using namespace std;
inline int Get() {int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}while('0'<=ch&&ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}return x*f;}
int n,m;
char s[N];
int fail[N<<1],mxlen[N<<1];
int ch[N<<1][10];
int last=1,cnt=1;
int pos[N<<1],id[N<<1];
ll ss[N];
void Insert(int f,int P) {
int p=last;
int v=++cnt;
pos[v]=P;
id[P]=v;
last=v;
mxlen[v]=mxlen[p]+1;
while(p&&!ch[p][f]) ch[p][f]=v,p=fail[p];
if(!p) return fail[v]=1,void();
int sn=ch[p][f];
if(mxlen[sn]==mxlen[p]+1) return fail[v]=sn,void();
int New=++cnt;
mxlen[New]=mxlen[p]+1;
memcpy(ch[New],ch[sn],sizeof(ch[sn]));
fail[New]=fail[sn];
fail[sn]=fail[v]=New;
while(p&&ch[p][f]==sn) ch[p][f]=New,p=fail[p];
}
int fa[N<<1][20];
vector<int>e[N<<1];
int rt[N<<1];
int ls[N*50],rs[N*50];
int tag[N*50];
int emp[N*50],rmax[N*50];
ll sum[N*50];
int tot;
int lx,rx;
void update(int v,int lx,int rx) {
sum[v]=sum[ls[v]]+sum[rs[v]];
int mid=lx+rx>>1;
ll R=rs[v]?emp[rs[v]]:rx-mid;
sum[v]+=1ll*rmax[ls[v]]*R;
if(!ls[v]||emp[ls[v]]==mid-lx+1) {
emp[v]=mid-lx+1+R;
} else {
emp[v]=emp[ls[v]];
}
if(rs[v]) rmax[v]=rmax[rs[v]];
else rmax[v]=rmax[ls[v]];
}
void Insert(int &v,int lx,int rx,int p) {
v=++tot;
tag[v]=1;
if(lx==rx) {
sum[v]=p;
rmax[v]=lx;
return ;
}
int mid=lx+rx>>1;
if(p<=mid) Insert(ls[v],lx,mid,p);
else Insert(rs[v],mid+1,rx,p);
update(v,lx,rx);
}
int Merge(int a,int b,int lx,int rx) {
if(!a||!b) return a+b;
int v=++tot;
int mid=lx+rx>>1;
ls[v]=Merge(ls[a],ls[b],lx,mid);
rs[v]=Merge(rs[a],rs[b],mid+1,rx);
update(v,lx,rx);
return v;
}
void dfs(int v) {
for(int i=1;i<=18;i++) fa[v][i]=fa[fa[v][i-1]][i-1];
if(pos[v]) Insert(rt[v],lx,rx,pos[v]);
for(int i=0;i<e[v].size();i++) {
int to=e[v][i];
dfs(to);
rt[v]=Merge(rt[v],rt[to],lx,rx);
}
}
int Find(int l,int r) {
int v=id[r];
for(int i=18;i>=0;i--)
if(fa[v][i]&&mxlen[fa[v][i]]>=r-l+1)
v=fa[v][i];
return v;
}
int query_mn(int v,int lx,int rx,int lim) {
if(!v||rx<lim) return 0;
if(lx==rx) return lx;
int mid=lx+rx>>1;
int x=query_mn(ls[v],lx,mid,lim);
if(x) return x;
else return query_mn(rs[v],mid+1,rx,lim);
}
int query_mx(int v,int lx,int rx) {
if(lx==rx) return lx;
int mid=lx+rx>>1;
if(rs[v]) return query_mx(rs[v],mid+1,rx);
else return query_mx(ls[v],lx,mid);
}
ll query_s(int v,int lx,int rx,int l,int r,int &L) {
if(lx>r) return 0;
if(rx<l) {
L=max(L,rmax[v]);
return 0;
}
if(l<=lx&&rx<=r) {
ll x=!v?rx-lx+1:emp[v];
ll ans=sum[v]+1ll*x*L;
L=max(L,rmax[v]);
return ans;
}
int mid=lx+rx>>1;
return query_s(ls[v],lx,mid,l,r,L)+query_s(rs[v],mid+1,rx,l,r,L);
}
ll solve(int v,int len) {
int mn=query_mn(rt[v],lx,rx,1),mx=query_mx(rt[v],lx,rx);
ll ans=0;
if(mn==mx) {
if(mx<n) ans+=ss[n-mx-1];
if(mn-len+1>1) ans+=ss[mn-len-1];
ans+=1ll*(n-mx)*(mn-len);
return ans;
}
if(mn<n) ans+=ss[n-mn-1];
if(mx-len+1>1) ans+=ss[mx-len-1];
if(mx-len+1>mn+1) ans-=ss[mx-len-mn];
int ed=query_mn(rt[v],lx,rx,mn+len-1);
if(ed) {
ed=max(ed,mx-len+1);
ans+=1ll*(n-ed)*(mn-1);
ed--;
} else ed=n-1;
int st=max(mn,mx-len+1);
if(ed>=st) {
int L=0;
ans+=query_s(rt[v],lx,rx,st,ed,L);
ans-=1ll*len*(ed-st+1);
}
return ans;
}
int main() {
n=Get(),m=Get();
for(int i=1;i<=n;i++) ss[i]=ss[i-1]+i;
lx=1,rx=n;
scanf("%s",s+1);
for(int i=1;i<=n;i++) Insert(s[i]-'0',i);
for(int i=2;i<=cnt;i++) {
e[fail[i]].push_back(i);
fa[i][0]=fail[i];
}
dfs(1);
int l,r;
while(m--) {
l=Get(),r=Get();
cout<<solve(Find(l,r),r-l+1)<<"
";
}
return 0;
}