Description
给定一棵 (n) 个结点的树,你从点 (x) 出发,每次等概率随机选择一条与所在点相邻的边走过去。
有 (Q) 次询问,每次询问给定一个集合 (S) ,求如果从 (x) 出发一直随机游走,直到点集 (S) 中所有点都至少经过一次的话,期望游走几步。
特别地,点 (x)(即起点)视为一开始就被经过了一次。
答案对 (998244353) 取模。
Solution
不妨设 (f_{i,S}) 表示在点 (i) 时,要遍历集合 (S) 的期望步数。那么对于一个询问 (S) ,答案就是 (f_{x,S}) 。
从两个方面来考虑如何求 (f) :
- 如果 (u otin S) ,由套路,显然满足 [f_{u,S}=frac{sum_{ ext{v is the neighbor of u}}f_{v,S}}{degree_u}+1]
- 如果 (uin S)
- 若 ({u}=S) ,显然 (f_{u,S}=0) ;
- 若 ({u} eq S) ,容易得到 (f_{u,S}=f_{u,S-{u}})
这样我们对于同一个状态 (S) 可以得到若干个方程,那么在这一个状态内高斯消元即可。
由于是树上消元,所以可以用[Codeforces 802L]Send the Fool Further! (hard)的方法化成 (f_u=k_uf_{fa_u}+b_u) 的形式 (O(n)) 求解。
总复杂度是 (O(nlog(n)2^n+Q)) ,其中 (log(n)) 是求逆元的复杂度。
Code
#include <bits/stdc++.h>
using namespace std;
const int N = 20, SIZE = (1<<18)+5, yzh = 998244353;
int n, q, x, u, v, bin[N], dg[N], S;
struct tt {int to, next; }edge[N<<1];
int path[N], top, k[N], b[N], f[N][SIZE];
int quick_pow(int a, int b) {
int ans = 1;
while (b) {
if (b&1) ans = 1ll*ans*a%yzh;
b >>= 1, a = 1ll*a*a%yzh;
}
return ans;
}
void dfs(int u, int fa) {
k[u] = b[u] = 0;
for (int i = path[u], v; i; i = edge[i].next)
if ((v = edge[i].to) != fa) dfs(v, u);
if (!(bin[u-1]&S)) {
if (dg[u] == 1 && x != u) k[u] = b[u] = 1;
else {
k[u] = dg[u], b[u] = dg[u];
for (int i = path[u], v; i; i = edge[i].next)
if ((v = edge[i].to) != fa) {
(k[u] -= k[v]) %= yzh; (b[u] += b[v]) %= yzh;
}
k[u] = quick_pow(k[u], yzh-2);
b[u] = 1ll*b[u]*k[u]%yzh;
}
}else {
if (S^bin[u-1]) {
k[u] = 0; b[u] = f[u][S^bin[u-1]];
}else k[u] = b[u] = 0;
}
}
void cal(int u, int fa) {
f[u][S] = (1ll*k[u]*f[fa][S]%yzh+b[u])%yzh;
for (int i = path[u], v; i; i = edge[i].next)
if ((v = edge[i].to) != fa) cal(v, u);
}
void add(int u, int v) {edge[++top] = (tt){v, path[u]}, path[u] = top; ++dg[v]; }
void work() {
scanf("%d%d%d", &n, &q, &x);
for (int i = 1; i < n; i++) {
scanf("%d%d", &u, &v);
add(u, v), add(v, u);
}
bin[0] = 1; for (int i = 1; i < N; i++) bin[i] = (bin[i-1]<<1);
for (int i = 1; i < bin[n]; i++) S = i, dfs(x, 0), cal(x, 0);
while (q--) {
S = 0; scanf("%d", &u);
for (int i = 1; i <= u; i++) scanf("%d", &v), S |= bin[v-1];
printf("%d
", (f[x][S]+yzh)%yzh);
}
}
int main() {work(); return 0; }