题意:
给出一个字典树
给出一个字符串
匹配的时候,如果在字典树上失配了,则回到根节点,从字符串的下一个位置开始匹配
给出q个询问
每次询问字符串区间[l,r],会在字典树上失配几次,最后停在字典树的哪个节点
求出字典树上的所有前缀哈希值,并标记这个哈希值对应的字典树上的节点
二分+上面的哈希值 求出以字符串的每个位置为起点,当失配的时候的下一次开始匹配的字符串位置
然后利用倍增求出 以每个位置为起点,失配2^i次的时候下一次开始匹配的字符串位置
对于每个查询[l,r],根据倍增的结果可以log的时间内求出失配次数
[最后一次失配地下一次开始匹配的字符串位置,r] 这段哈希值对应的字典树节点就是最后停留的节点
#include<map> #include<cstdio> #include<cstring> #include<iostream> //#include<ctime> #define N 100006 using namespace std; const int mod=10000019; int m; char s[N]; int tr[N][26],tot; int to[N][21]; typedef unsigned long long ULL; //map<ULL,int>mp,th; int mp[mod],th[mod]; ULL who[mod]; ULL hh[N],bit[N]; int bit2[21]; int yy; ULL yyy; void push(ULL val,int x,int tim) { yy=val%mod; while(th[yy]==tim) { ++yy; if(yy==mod) yy=0; } mp[yy]=x; who[yy]=val; th[yy]=tim; } void dfs(int x,ULL val,int tim) { //mp[val]=x; //th[val]=tim; //mp[yy]=x; //th[yy]=tim; push(val,x,tim); for(int i=0;i<26;++i) if(tr[x][i]) dfs(tr[x][i],val*27+i+1,tim); } int check(int l,int r,int tim) { if(r>m || l>r) return 0; yyy=hh[r]-hh[l-1]*bit[r-l+1]; yy=yyy%mod; if(th[yy]!=tim) return 0; while(th[yy]==tim) { if(who[yy]==yyy) return mp[yy]; ++yy; if(yy==mod) yy=0; } return 0; } int main() { // freopen("data2.txt","r",stdin); // freopen("11.txt","w",stdout); int T,n,q,fi; int l,r,mid,tmp; int p; int sum; char ss[3]; bit2[0]=1; for(int i=1;i<21;++i) bit2[i]=bit2[i-1]<<1; bit[0]=1; for(int i=1;i<N;++i) bit[i]=bit[i-1]*27; scanf("%d",&T); for(int t=1;t<=T;++t) { scanf("%d%d%d",&n,&m,&q); for(int i=1;i<=n;++i) { scanf("%d%s",&fi,ss); tr[fi][ss[0]-'a']=++tot; } scanf("%s",s+1); for(int i=1;i<=m;++i) hh[i]=hh[i-1]*27+(s[i]-'a'+1); // int t1=clock(); dfs(0,0,t); // printf("%.3lf ",(clock()-t1)/1000.0); // int t2=clock(); for(int i=1;i<=m;++i) { l=i; r=m+1; while(l<=r) { mid=l+r>>1; if(check(i,mid,t)) l=mid+1; else { tmp=mid; r=mid-1; } } to[i][0]=tmp+1; } // printf("%.3lf ",(clock()-t2)/1000.0); // int t3=clock(); p=0; while(bit2[p]<m) ++p; for(int i=0;i<=p;++i) to[m+1][i]=to[m+2][i]=m+2; for(int i=1;i<=p;++i) for(int j=1;j<=m;++j) to[j][i]=to[to[j][i-1]][i-1]; // printf("%.3lf ",(clock()-t3)/1000.0); while(q--) { scanf("%d%d",&l,&r); sum=0; for(int i=p;i>=0;--i) if(l<=r && to[l][i]<=r+1) { l=to[l][i]; sum+=bit2[i]; } check(l,r,t); printf("%d %d ",sum,check(l,r,t)); } for(int i=0;i<=tot;++i) memset(tr[i],0,sizeof(tr[i])); tot=0; } }