直接把 r 加进去建虚树, 考虑虚树上的dp, 我们考虑虚树的dfs序的顺序dp过去。
dp[ i ][ j ] 表示到 i 这个点为止, 分成 j 组有多少种合法方案。
dp[ i ][ j ] = dp[ i - 1 ][ j ] * (j - have[ i ]) + dp[ i - 1 ][ j - 1 ], have[ i ] 表示 i 的祖先中有多少个在a中出现。
#include<bits/stdc++.h> using namespace std; const int N = (int)1e5 + 7; const int mod = (int)1e9 + 7; const int LOG = 17; int n, q, k, m, r, cas, a[N]; int depth[N], pa[N][LOG]; int in[N], ot[N], dfs_clock; int col[N], dp[301]; vector<int> G[N]; void dfs(int u, int fa) { in[u] = ++dfs_clock; depth[u] = depth[fa] + 1; pa[u][0] = fa; for(int i = 1; i < LOG; i++) { pa[u][i] = pa[pa[u][i - 1]][i - 1]; } for(auto &v : G[u]) { if(v == fa) continue; dfs(v, u); } ot[u] = dfs_clock; } inline int getLca(int u, int v) { if(depth[u] < depth[v]) swap(u, v); int d = depth[u] - depth[v]; for(int i = LOG - 1; i >= 0; i--) { if(d >> i & 1) { u = pa[u][i]; } } if(u == v) return u; for(int i = LOG - 1; i >= 0; i--) { if(pa[u][i] != pa[v][i]) { u = pa[u][i]; v = pa[v][i]; } } return pa[u][0]; } void go(int u, int fa, int have) { if(col[u] == cas) { for(int i = m; i >= 0; i--) { if(i < have + 1) dp[i] = 0; else { dp[i] = 1LL * dp[i] * (i - have) % mod + dp[i - 1]; if(dp[i] >= mod) dp[i] -= mod; } } } for(auto &v : G[u]) { if(v == fa) continue; go(v, u, have + (col[u] == cas)); } } int main() { scanf("%d%d", &n, &q); for(int i = 1; i < n; i++) { int u, v; scanf("%d%d", &u, &v); G[u].push_back(v); G[v].push_back(u); } dfs(1, 0); for(cas = 1; cas <= q; cas++) { vector<int> P; scanf("%d%d%d", &k, &m, &r); for(int i = 1; i <= k; i++) scanf("%d", &a[i]), col[a[i]] = cas; a[++k] = r; for(int i = 1; i <= k; i++) P.push_back(a[i]); sort(a + 1, a + 1 + k, [&](int x, int y) {return in[x] < in[y];}); for(int i = 1; i < k; i++) P.push_back(getLca(a[i], a[i + 1])); sort(P.begin(), P.end()); P.erase(unique(P.begin(), P.end()), P.end()); sort(P.begin(), P.end(), [&](int x, int y) {return in[x] < in[y];}); for(auto &t : P) G[t].clear(); vector<int> S; for(auto &t : P) { while(S.size() && ot[S.back()] < in[t]) S.pop_back(); if(S.size()) { G[S.back()].push_back(t); G[t].push_back(S.back()); } S.push_back(t); } for(int i = 0; i <= m; i++) dp[i] = (i == 0); go(r, 0, 0); int ans = 0; for(int i = 1; i <= m; i++) { ans += dp[i]; if(ans >= mod) ans -= mod; } printf("%d ", ans); } return 0; } /** **/