建出 SAM。 对于每个子串,处理出这个子串最长合法的后缀。
#include<bits/stdc++.h> #define N 200007 using namespace std; vector<int> v[N]; int last=1,tot=1,ch[N][26],mx[N],len[N],q,p,fa[N],nq; char s[N],g[N]; void Sam(int x){ q=++tot; len[q]=len[last]+1; for (;last&&!ch[last][x];last=fa[last]) ch[last][x]=q; if (!last) fa[q]=1; else{ int p=ch[last][x]; if (len[last]+1==len[p]) fa[q]=p; else { nq=++tot; len[nq]=len[last]+1; memcpy(ch[nq],ch[p],sizeof ch[p]); fa[nq]=fa[p]; fa[q]=fa[p]=nq; for (;last&&ch[last][x]==p;last=fa[last]) ch[last][x]=nq; } } last=q; } int Len,now,sum[N],l,k,ma[N]; void Lxf(int x){ for (auto i:v[x]) Lxf(i),mx[x]=max(mx[i],mx[x]); } long long nb; signed main () { scanf("%s",s+1); scanf("%s",g+1); Len=strlen(s+1); for (int i=1;i<=Len;i++) { Sam(s[i]-'a'); sum[i]=sum[i-1]+(g[i]=='0'); } now=1; scanf("%d",&k); for (int i=1;i<=Len;i++) { now=ch[now][s[i]-'a']; while (sum[i]-sum[p]>k) p++; mx[now]=max(mx[now],i-p); } for (int i=1;i<=tot;i++) v[fa[i]].push_back(i); Lxf(1); for (int i=1;i<=tot;i++) nb+=max(0,min(mx[i],len[i])-len[fa[i]]); printf("%lld ",nb); }