用 (kmp) 暴力处理经过 (lca) 的匹配,这一部分复杂度为 (O(sum|s|))。然后就只用考虑直上直下的链的匹配,离线后对询问串建 (AC) 自动机,在原树上遍历时加入贡献,答案差分统计,即长链的匹配减去短链的匹配,用树状数组维护 (fail) 树子树和即可。
#include<bits/stdc++.h>
#define maxn 300010
#define lowbit(x) (x&(-x))
using namespace std;
template<typename T> inline void read(T &x)
{
x=0;char c=getchar();bool flag=false;
while(!isdigit(c)){if(c=='-')flag=true;c=getchar();}
while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
if(flag)x=-x;
}
int n,m,root;
int ans[maxn],f[maxn][19],dep[maxn],nxt[maxn];
char s[maxn],col[maxn],str[maxn],tmp[maxn];
struct node
{
int p,v,id;
node(int a=0,int b=0,int c=0)
{
p=a,v=b,id=c;
}
};
vector<node> q1[maxn],q2[maxn];
struct edge
{
int to,nxt;
char v;
edge(int a=0,int b=0,char c=0)
{
to=a,nxt=b,v=c;
}
}e[maxn];
int head[maxn],edge_cnt;
void add(int from,int to,char val)
{
e[++edge_cnt]=edge(to,head[from],val),head[from]=edge_cnt;
}
struct AC
{
int tot,cnt;
int ch[maxn][28],fail[maxn],in[maxn],out[maxn],tr[maxn];
vector<int> ve[maxn];
void update(int x,int v)
{
if(!x) return;
x=in[x];
while(x<=n) tr[x]+=v,x+=lowbit(x);
}
int ask(int x)
{
int v=0;
while(x) v+=tr[x],x-=lowbit(x);
return v;
}
int query(int x)
{
return ask(out[x])-ask(in[x]-1);
}
int insert(int type=0)
{
int p=root,len=strlen(s+1);
if(type) reverse(s+1,s+len+1);
for(int i=1;i<=len;++i)
{
int c=s[i]-'a';
if(!ch[p][c]) ch[p][c]=++tot;
p=ch[p][c];
}
if(type) reverse(s+1,s+len+1);
return p;
}
void dfs_dfn(int x)
{
in[x]=++cnt;
for(int i=0;i<ve[x].size();++i) dfs_dfn(ve[x][i]);
out[x]=cnt;
}
void build()
{
queue<int> q;
for(int c=0;c<26;++c)
if(ch[root][c])
q.push(ch[root][c]);
while(!q.empty())
{
int x=q.front();
q.pop();
for(int c=0;c<26;++c)
{
int y=ch[x][c];
if(y) fail[y]=ch[fail[x]][c],q.push(y);
else ch[x][c]=ch[fail[x]][c];
}
}
for(int i=1;i<=tot;++i) ve[fail[i]].push_back(i);
dfs_dfn(root);
}
}A,B;
void dfs_pre(int x,int fa)
{
dep[x]=dep[f[x][0]=fa]+1;
for(int i=1;i<=17;++i) f[x][i]=f[f[x][i-1]][i-1];
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to;
if(y==fa) continue;
col[y]=e[i].v,dfs_pre(y,x);
}
}
int lca(int x,int y)
{
if(dep[x]<dep[y]) swap(x,y);
for(int i=17;i>=0;--i)
if(f[x][i]&&dep[f[x][i]]>=dep[y])
x=f[x][i];
if(x==y) return x;
for(int i=17;i>=0;--i)
if(f[x][i]&&f[x][i]!=f[y][i])
x=f[x][i],y=f[y][i];
return f[x][0];
}
int get(int x,int k)
{
for(int i=0;i<=17;++i)
if((k>>i)&1)
x=f[x][i];
return x;
}
void work(int x,int y,int id)
{
int anc=lca(x,y),p1=A.insert(),p2=B.insert(1),len=strlen(s+1),p,cnt1=0,cnt2=0,pos=0;
p=get(x,max(dep[x]-dep[anc]-len+1,0));
q2[x].push_back(node(p2,1,id)),q2[p].push_back(node(p2,-1,id));
while(p!=anc) str[++cnt1]=col[p],p=f[p][0];
p=get(y,max(dep[y]-dep[anc]-len+1,0));
q1[y].push_back(node(p1,1,id)),q1[p].push_back(node(p1,-1,id));
while(p!=anc) tmp[++cnt2]=col[p],p=f[p][0];
for(int i=cnt2;i;--i) str[++cnt1]=tmp[i];
for(int i=1;i<=len;++i) nxt[i]=0;
for(int i=2;i<=len;++i)
{
while(pos&&s[pos+1]!=s[i]) pos=nxt[pos];
nxt[i]=(pos+=s[pos+1]==s[i]);
}
pos=0;
for(int i=1;i<=cnt1;++i)
{
while(pos&&s[pos+1]!=str[i]) pos=nxt[pos];
pos+=s[pos+1]==str[i];
if(pos==len) ans[id]++,pos=nxt[pos];
}
}
void dfs_ans(int x,int p1,int p2)
{
A.update(p1,1),B.update(p2,1);
for(int i=0;i<q1[x].size();++i)
ans[q1[x][i].id]+=A.query(q1[x][i].p)*q1[x][i].v;
for(int i=0;i<q2[x].size();++i)
ans[q2[x][i].id]+=B.query(q2[x][i].p)*q2[x][i].v;
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to,v=e[i].v-'a';
if(y==f[x][0]) continue;
dfs_ans(y,A.ch[p1][v],B.ch[p2][v]);
}
A.update(p1,-1),B.update(p2,-1);
}
int main()
{
read(n),read(m);
for(int i=1;i<n;++i)
{
int x,y;
read(x),read(y),scanf("%s",s),add(x,y,s[0]),add(y,x,s[0]);
}
dfs_pre(1,0);
for(int i=1;i<=m;++i)
{
int x,y;
read(x),read(y),scanf("%s",s+1);
if(x!=y) work(x,y,i);
}
A.build(),B.build(),dfs_ans(1,root,root);
for(int i=1;i<=m;++i) printf("%d
",ans[i]);
return 0;
}