写过阿狸的打字机应该就可以写这道题了…对S中的n个串构建AC自动机和fail树,然后每新来一个T中的串,就把这个串扔进AC自动机里走一遍,会经过一些节点,每个节点在fail树上到根的路径上的节点对应的串都在这个串里出现,那么我们把这些节点到根节点的路径的并上的每个节点都+1,那么按节点的dfs序排序,每个节点处+1,相邻节点(不包括第一个和最后一个)处-1即可.倍增LCA并不会被卡.如果写树剖LCA会更快一点.
#include<cstdio> #include<cstring> #include<queue> #include<algorithm> using namespace std; struct node{ node* ch[26],*fail; int num; node(int x){num=x;fail=0;memset(ch,0,sizeof(ch));} }*root;int tot=0; const int maxn=3000005; int pos[maxn]; void Add(char *c,int x){ node* p=root; while(*c){ int t=(*c)-'a'; if(p->ch[t]==NULL)p->ch[t]=new node(++tot); p=p->ch[t];++c; } pos[x]=p->num; } char str[maxn]; struct edge{ int to,next; }lst[maxn];int len=1,first[maxn]; void addedge(int a,int b){ lst[len].to=b;lst[len].next=first[a];first[a]=len++; } int point[maxn],dfn[maxn],sz[maxn],depth[maxn],p[maxn][22],T; void dfs(int x){ sz[x]=1;dfn[x]=++T; for(int j=0;p[x][j];++j)p[x][j+1]=p[p[x][j]][j]; for(int pt=first[x];pt;pt=lst[pt].next){ depth[lst[pt].to]=depth[x]+1;p[lst[pt].to][0]=x; dfs(lst[pt].to); sz[x]+=sz[lst[pt].to]; } } int lca(int u,int v){ if(depth[u]<depth[v])swap(u,v); for(int j=21;j>=0;--j){ if(depth[p[u][j]]>=depth[v])u=p[u][j]; } if(u==v)return u; for(int j=21;j>=0;--j){ if(p[u][j]!=p[v][j])u=p[u][j],v=p[v][j]; } return p[v][0]; } void getfail(){ queue<node*> q;q.push(root); while(!q.empty()){ node* x=q.front();q.pop(); for(int i=0;i<26;++i){ if(x->ch[i]){ if(x==root)x->ch[i]->fail=root; else x->ch[i]->fail=x->fail->ch[i]; q.push(x->ch[i]); addedge(x->ch[i]->fail->num,x->ch[i]->num); }else if(x==root)x->ch[i]=root; else x->ch[i]=x->fail->ch[i]; } } } int seq[maxn],cnt; int c[maxn]; void bit_add(int x){ for(;x<=tot;x+=x&(-x))c[x]++; } void bit_del(int x){ for(;x<=tot;x+=x&(-x))c[x]--; } int bit_sum(int x){ int ans=0;for(;x;x-=x&(-x))ans+=c[x];return ans; } int main(){ int n;scanf("%d",&n); root= new node(++tot); for(int i=1;i<=n;++i){ scanf("%s",str); Add(str,i); } getfail(); depth[1]=1;dfs(1); for(int i=1;i<=tot;++i)point[dfn[i]]=i; int q;scanf("%d",&q); int typ,x; while(q--){ scanf("%d",&typ); if(typ==1){ scanf("%s",str); node* p=root;cnt=0; for(int i=0;str[i];++i){ p=p->ch[str[i]-'a'];seq[cnt++]=dfn[p->num]; } sort(seq,seq+cnt); bit_add(seq[0]); for(int i=1;i<cnt;++i){ if(seq[i]!=seq[i-1]){ bit_add(seq[i]);bit_del(dfn[lca(point[seq[i]],point[seq[i-1]])]); } } }else{ scanf("%d",&x);x=pos[x]; printf("%d ",bit_sum(dfn[x]+sz[x]-1)-bit_sum(dfn[x]-1)); } } return 0; }
#include<cstdio> #include<cstring> #include<queue> #include<algorithm> using namespace std; struct node{ node* ch[26],*fail; int num; node(int x){num=x;fail=0;memset(ch,0,sizeof(ch));} }*root;int tot=0; const int maxn=2000005; int pos[maxn]; void Add(char *c,int x){ node* p=root; while(*c){ int t=(*c)-'a'; if(p->ch[t]==NULL)p->ch[t]=new node(++tot); p=p->ch[t];++c; } pos[x]=p->num; } char str[maxn]; struct edge{ int to,next; }lst[maxn];int len=1,first[maxn]; void addedge(int a,int b){ lst[len].to=b;lst[len].next=first[a];first[a]=len++; } int point[maxn],dfn[maxn],sz[maxn],depth[maxn],hvy[maxn],prt[maxn],top[maxn],T; void dfs1(int x){ sz[x]=1; for(int pt=first[x];pt;pt=lst[pt].next){ depth[lst[pt].to]=depth[x]+1;prt[lst[pt].to]=x; dfs1(lst[pt].to); sz[x]+=sz[lst[pt].to]; if(sz[lst[pt].to]>sz[hvy[x]])hvy[x]=lst[pt].to; } } void dfs2(int x){ dfn[x]=++T;point[T]=x; if(hvy[x]){ top[hvy[x]]=top[x]; dfs2(hvy[x]); for(int pt=first[x];pt;pt=lst[pt].next){ if(lst[pt].to!=hvy[x]){ top[lst[pt].to]=lst[pt].to; dfs2(lst[pt].to); } } } } int lca(int u,int v){ int t1=top[u],t2=top[v]; while(t1!=t2){ if(depth[t1]<depth[t2]){ swap(t1,t2);swap(u,v); } u=prt[t1];t1=top[u]; } return depth[u]<depth[v]?u:v; } void getfail(){ queue<node*> q;q.push(root); while(!q.empty()){ node* x=q.front();q.pop(); for(int i=0;i<26;++i){ if(x->ch[i]){ if(x==root)x->ch[i]->fail=root; else x->ch[i]->fail=x->fail->ch[i]; q.push(x->ch[i]); addedge(x->ch[i]->fail->num,x->ch[i]->num); }else if(x==root)x->ch[i]=root; else x->ch[i]=x->fail->ch[i]; } } } int seq[maxn],cnt; int c[maxn]; void bit_add(int x){ for(;x<=tot;x+=x&(-x))c[x]++; } void bit_del(int x){ for(;x<=tot;x+=x&(-x))c[x]--; } int bit_sum(int x){ int ans=0;for(;x;x-=x&(-x))ans+=c[x];return ans; } int main(){ int n;scanf("%d",&n); root= new node(++tot); for(int i=1;i<=n;++i){ scanf("%s",str); Add(str,i); } getfail(); depth[1]=1;dfs1(1);top[1]=1;dfs2(1); for(int i=1;i<=tot;++i)point[dfn[i]]=i; int q;scanf("%d",&q); int typ,x; while(q--){ scanf("%d",&typ); if(typ==1){ scanf("%s",str); node* p=root;cnt=0; for(int i=0;str[i];++i){ p=p->ch[str[i]-'a'];seq[cnt++]=dfn[p->num]; } sort(seq,seq+cnt); bit_add(seq[0]); for(int i=1;i<cnt;++i){ if(seq[i]!=seq[i-1]){ bit_add(seq[i]);bit_del(dfn[lca(point[seq[i]],point[seq[i-1]])]); } } }else{ scanf("%d",&x);x=pos[x]; printf("%d ",bit_sum(dfn[x]+sz[x]-1)-bit_sum(dfn[x]-1)); } } return 0; }