看到题目的第一反应当然是暴力:对于串s建后缀自动机,每次询问中,求w对应的子串在s的SAM中的right集合。O(qmk)听上去显然过不了。
数据范围有个∑w<=1e5,也就是说,q*k<=1e5,当q更小或k更小时可以用不同的方法。
k更小时,会发现每个w的子串数可能会很小,子串数可能还没m多。这个时候将m个[li,ri]的询问可能会有重复的,所以可以用vector存一下对于每一个[L,R],满足li=L且ri=R的询问的编号是哪些。求出w对于每个[L,R]有多少个满足li=L且ri=R的询问的编号在[a,b](可以用vector自带的lower_bound),和w[L,R]在s中的出现次数。这样,时间复杂度是O(q*log m*k2)。
q更小时,对于每个w可以求出它的每个[1,i]的前缀会匹配到s的SAM的哪个点(记作pla[i])、能匹配s多长。SAM中的每个点的fail指针指向的点都是它的后缀。所以对于每个w[li,ri],可以先走到pla[ri],再倍增地顺着fail指针走。时间复杂度是O(q*m*log k)。
听说代码很难调?
#include <bits/stdc++.h> #define rep(i,x,y) for(register int i=(x);i<=(y);++i) #define dwn(i,x,y) for(register int i=(x);i>=(y);--i) #define re register #define LL long long #define maxn 200010 #define block 333 using namespace std; inline LL read() { LL x=0,f=1; char ch=getchar(); while(isdigit(ch)==0 && ch!='-')ch=getchar(); if(ch=='-')f=-1,ch=getchar(); while(isdigit(ch))x=(x<<3)+(x<<1)+ch-'0',ch=getchar(); return x*f; } inline void write(LL x) { LL f=0;char ch[20]; if(!x){puts("0");return;} if(x<0){putchar('-');x=-x;} while(x)ch[++f]=x%10+'0',x/=10; while(f)putchar(ch[f--]); putchar(' '); } int ch[maxn][30],dis[maxn],ord[maxn],c[maxn],fa[maxn],cnt,rt,lst,anc[maxn][21]; int n,m,q,k,ql[maxn],qr[maxn],pla[maxn],lth[maxn]; LL r[maxn]; vector<int>to[block][block]; char s[maxn],w[maxn]; int gx(char c){return c-'a';} void go(int & u,int & len, int x) { while(!ch[u][x]&&u)u=fa[u],len=dis[u]; if(ch[u][x])u=ch[u][x],++len; else u=rt,len=0; } void extend(LL pos) { int x=gx(s[pos]),p=lst,np=++cnt;lst=np;dis[np]=pos; for(;p&&!ch[p][x];p=fa[p])ch[p][x]=np; if(!p)fa[np]=rt; else { LL q=ch[p][x]; if(dis[q]==dis[p]+1)fa[np]=q; else { LL nq=++cnt;dis[nq]=dis[p]+1; memcpy(ch[nq],ch[q],sizeof(ch[q])); fa[nq]=fa[q],fa[q]=fa[np]=nq; for(;p&&ch[p][x]==q;p=fa[p])ch[p][x]=nq; } } } void getr() { for(int u=rt,i=1;i<=n;++i)u=ch[u][gx(s[i])],++r[u]; rep(i,1,cnt)c[dis[i]]++; rep(i,1,n)c[i]+=c[i-1]; rep(i,1,cnt)ord[c[dis[i]]--]=i; dwn(i,cnt,1)r[fa[ord[i]]]+=r[ord[i]]; } void getfa(){rep(l,1,cnt){int i=ord[l];anc[i][0]=fa[i];rep(j,1,20)anc[i][j]=anc[anc[i][j-1]][j-1];}} int main() { lst=rt=++cnt; n=read(),m=read(),q=read(),k=read(); scanf("%s",s+1); rep(i,1,n)extend(i);getr(); rep(i,1,m){ql[i]=read()+1,qr[i]=read()+1;if(k<=block)to[ql[i]][qr[i]].push_back(i);} if(k<=block) { while(q--) { scanf("%s",w+1); int a=read()+1,b=read()+1;LL ans=0; rep(i,1,k) { int u=rt; rep(j,i,k) { if(ch[u][gx(w[j])]) { u=ch[u][gx(w[j])]; vector<int>::iterator L=lower_bound(to[i][j].begin(),to[i][j].end(),a); vector<int>::iterator R=upper_bound(to[i][j].begin(),to[i][j].end(),b); ans+=(R-L)*r[u]; } else break; } } write(ans); } } else { getfa(); while(q--) { scanf("%s",w+1); int a=read()+1,b=read()+1;LL ans=0; for(int len=0,u=rt,i=1;i<=k;i++)go(u,len,gx(w[i])),lth[i]=len,pla[i]=u; rep(i,a,b) { if(lth[qr[i]]<qr[i]-ql[i]+1)continue; else { int u=pla[qr[i]]; dwn(j,19,0)if(dis[anc[u][j]]>=qr[i]-ql[i]+1)u=anc[u][j]; ans+=r[u]; } } write(ans); } } return 0; }