我们考虑如果我们能快速的得出一条路线上的字符串组成的字典树,那么问题就迎刃而解了。开太多的字典树,开不下,我们可持久化以下就好了。 持久化出根到每个结点的字典树,然后ans(a) + ans(b) - 2 * ans(lca)即可。
可持久化字典树应该如何操作呢,我们考虑,对于一个字典树x,我们向其加入一个字符串s1,得到一棵新的字典树y,两个字典树绝大多数部分均相同,只有遍历s1的这条路径上,会有所差异。所以我们考虑,对于字典树y,我们直接把和x相同的部分,直接指向x,不另行新建。这样子,每次版本更新,我们只会新建len(加入字符串长度)个的节点
1 #include <cstdio> 2 #include <cstring> 3 #include <cmath> 4 #include <algorithm> 5 #include <stack> 6 using namespace std; 7 const int MAXN = 1000005,MAXM = 1000005; 8 int cnt,n,tot,q; 9 int head[MAXN],rt[MAXN],dep[MAXN],siz[MAXN],to[MAXM],nxt[MAXM]; 10 int p[MAXN][30]; 11 int ch[MAXN][26]; 12 char str[MAXM][15]; 13 void add(int x,int y,char *s) 14 { 15 nxt[++cnt] = head[x]; 16 to[cnt] = y; 17 head[x] = cnt; 18 memcpy(str[cnt],s,sizeof(str[cnt])); 19 } 20 void dfs(int x,int frm) 21 { 22 for (int i = head[x];i;i = nxt[i]) 23 { 24 if (to[i] == frm) continue; 25 p[to[i]][0] = x; 26 dep[to[i]] = dep[x] + 1; 27 int u = rt[x],v = rt[to[i]] = ++tot,lenn = strlen(str[i] + 1); 28 for (int j = 1;j <= lenn;j++) 29 { 30 for (int o = 0;o < 26;o++) 31 { 32 siz[v] += siz[ch[u][o]]; 33 ch[v][o] = ch[u][o]; 34 } 35 siz[v]++; 36 u = ch[u][str[i][j] - 'a']; 37 ch[v][str[i][j] - 'a'] = ++tot; 38 v = tot; 39 } 40 siz[v]++; 41 dfs(to[i],x); 42 } 43 } 44 int lca(int x,int y) 45 { 46 int jqe = log2(n); 47 if (dep[x] < dep[y]) swap(x,y); 48 for (int i = jqe;i >= 0;i--) 49 if (dep[x] - (1 << i) >= dep[y]) 50 x = p[x][i]; 51 if (x == y) return x; 52 for (int i = jqe;i >= 0;i--) 53 if (p[x][i] != p[y][i]) x = p[x][i],y = p[y][i]; 54 return p[x][0]; 55 } 56 void lca_init() 57 { 58 int jqe = log2(n); 59 for (int i = 1;i <= jqe;i++) 60 for (int j = 1;j <= n;j++) 61 p[j][i] = p[p[j][i - 1]][i - 1]; 62 } 63 int solve(int x,char *s) 64 { 65 int u = rt[x],lenn = strlen(s + 1); 66 for (int i = 1;i <= lenn;i++) 67 u = ch[u][s[i] - 'a']; 68 return siz[u]; 69 } 70 int main() 71 { 72 scanf("%d",&n); 73 for (int i = 1;i <= n - 1;i++) 74 { 75 int u,v; 76 char s[15]; 77 scanf("%d%d%s",&u,&v,s + 1); 78 add(u,v,s); 79 add(v,u,s); 80 } 81 dep[1] = 1; 82 dfs(1,0); 83 lca_init(); 84 scanf("%d",&q); 85 for (int i = 1;i <= q;i++) 86 { 87 int u,v; 88 char s[15]; 89 scanf("%d%d%s",&u,&v,s + 1); 90 int t = lca(u,v); 91 printf("%d ",solve(u,s) + solve(v,s) - 2 * solve(t,s)); 92 } 93 return 0; 94 } 95 ?