首先注意到 k 最大只有 50，因此可以对每一个 k 都开一个数组，预先处理出需要的前缀和。
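设 sum[u][k] 表示从根节点到节点 u 的路径上所有节点深度的 k 次方之和（根节点深度为 0）。一遍 DFS 即可求出全部 sum：每个子节点的值等于父节点的值加上自己深度的 k 次方。预处理复杂度为 O(50n)；如果每一项都用快速幂计算，还要再乘上 O(log 50) 的因子，而按 k 从小到大每次多乘一次深度即可去掉这个因子。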
每次查询的复杂度就是求 LCA 的复杂度 O(log n)。对于路径 (x, y)，令 z = lca(x, y)，则 ans(x, y) = sum[x][k] + sum[y][k] - sum[z][k] - sum[fa[z]][k]，其中 fa[z] 是 z 的父节点（z 为根时取 sum[0][k] = 0）。注意做减法之后要先加上 mod 再取模，避免结果出现负数。
#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
using namespace std;
#define enter puts("")
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
const int maxn = 3e5 + 5;
const ll mod = 998244353;

// Reads one integer from stdin, skipping any non-digit characters before it.
// The value is negated only when the character right before the first digit
// is '-'. Returns 0 if EOF is reached before any digit.
inline ll read()
{
    ll ans = 0;
    // int (not char): getchar() may return EOF, and isdigit() is only defined
    // for EOF and values representable as unsigned char.
    int ch = getchar(), last = ' ';
    while(ch != EOF && !isdigit(ch)) {last = ch; ch = getchar();}
    while(isdigit(ch)) {ans = ans * 10 + ch - '0'; ch = getchar();}
    if(last == '-') ans = -ans;
    return ans;
}

// Writes x in decimal to stdout (no trailing separator).
inline void write(ll x)
{
    if(x < 0) x = -x, putchar('-');
    if(x >= 10) write(x / 10);
    putchar(x % 10 + '0');
}

int n, m;

// Adjacency list stored as a forward-star: head[u] is the index of u's most
// recently added edge, e[i].nxt chains to the previous one, -1 terminates.
struct Edge
{
    int nxt, to;
}e[maxn << 1];
int head[maxn], ecnt = -1;

// Appends the directed edge x -> y.
void addEdge(int x, int y)
{
    e[++ecnt] = (Edge){head[x], y};
    head[x] = ecnt;
}

const int K = 50;     // largest exponent a query may ask for
const int N = 22;     // highest binary-lifting level (2^22 > maxn)
int dis[maxn];        // dis[u]: depth of u; the root keeps its initial 0
ll sum[maxn][K + 1];  // sum[u][j]: sum of depth^j over the root..u path, mod `mod`
int fa[maxn][N + 3];  // fa[u][i]: 2^i-th ancestor of u, 0 once above the root

// Fills dis[], sum[][] and fa[][] for every node reachable from `now`,
// treating `f` as now's parent. dis[now] and sum[now][] must already hold
// their values (both are 0 for the root).
//
// Uses an explicit stack instead of recursion: the tree can be a path of
// ~3e5 nodes, which would otherwise mean ~3e5 nested call frames.
// Nodes are finalised in pre-order, so a node's ancestors are always done
// before the node itself.
void dfs(int now, int f)
{
    vector<int> stk;
    fa[now][0] = f;
    stk.push_back(now);
    while(!stk.empty())
    {
        int u = stk.back();
        stk.pop_back();
        // Every ancestor of u was popped earlier, so its table is complete.
        for(int i = 1; i <= N; ++i)
            fa[u][i] = fa[fa[u][i - 1]][i - 1];
        for(int i = head[u]; i != -1; i = e[i].nxt)
        {
            int to = e[i].to;
            if(to == fa[u][0]) continue;
            fa[to][0] = u;
            dis[to] = dis[u] + 1;
            // Build dis[to]^j incrementally instead of one fast power per j.
            // pw < mod and dis < maxn, so pw * dis fits in a long long.
            ll pw = 1;
            for(int j = 1; j <= K; ++j)
            {
                pw = pw * dis[to] % mod;
                sum[to][j] = (sum[u][j] + pw) % mod;
            }
            stk.push_back(to);
        }
    }
}

// Lowest common ancestor of x and y via binary lifting on fa[][] and dis[].
int lca(int x, int y)
{
    if(dis[x] < dis[y]) swap(x, y);
    for(int i = N; i >= 0; --i)
        if(dis[x] - (1 << i) >= dis[y]) x = fa[x][i];
    if(x == y) return x;
    for(int i = N; i >= 0; --i)
        if(fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i];
    return fa[x][0];
}

// Sum of depth^k over all nodes on the x..y path, mod `mod`.
// Inclusion-exclusion over root-prefix sums: z = lca(x, y) is counted in both
// sum[x] and sum[y], so it is removed once together with everything above it
// (sum[fa[z]]). For the root, fa[z][0] == 0 and sum[0][k] == 0.
// Returns 0 when k is outside [1, K], so out-of-range k never indexes past sum[][].
ll solve(int x, int y, int k)
{
    if(k < 1 || k > K) return 0;
    int z = lca(x, y);
    // Adding mod before the final % keeps the result non-negative.
    return (sum[x][k] + sum[y][k] - (sum[z][k] + sum[fa[z][0]][k]) % mod + mod) % mod;
}

// Input: n, then n-1 undirected edges, then m queries "x y k".
// Output: one answer per line.
int main()
{
    Mem(head, -1);
    n = read();
    for(int i = 1; i < n; ++i)
    {
        int x = read(), y = read();
        addEdge(x, y); addEdge(y, x);
    }
    dfs(1, 0);
    m = read();
    for(int i = 1; i <= m; ++i)
    {
        int x = read(), y = read(), k = read();
        write(solve(x, y, k)), enter;
    }
    return 0;
}