题面
解析
求到$S$集合每个点走一次的期望,即求$E(max(S))$,套上$Min-Max$容斥,即是求$E(min(T)),Tsubseteq S$
考虑对每种集合做一次$dp$,外层枚举$P subseteq U$,$dp[u]$表示点$u$到$P$集合内任意一点的期望时间, $deg[u]$表示点$u$的度数,$fa$为$u$的父亲节点,$v$为$u$的儿子节点。
若$u in P$,则$dp[u] = 0$
否则有:$dp[u] = frac{1}{deg[u]}(dp[fa]+sum dp[v])+1$
然后是一个我没总结过的较常见套路,设$dp[u] = A[u] * dp[fa] + B[u]$,带入上式化简:$$deg[u]*dp[u]=dp[fa]+sum(A[v]*dp[u]+B[v])+deg[u]$$$$(deg[u]-sum A[v])*dp[u]=dp[fa]+(sum B[v])+deg[u]$$$$dp[u]=frac{1}{deg[u]-sum A[v]}*dp[fa]+frac{deg[u]+sum B[v]}{deg[u]-sum A[v]}$$
故:$$A[u]=frac{1}{deg[u]-sum A[v]}, B[u]=frac{deg[u]+sum B[v]}{deg[u]-sum A[v]}$$
对于根节点,由于其没有父节点,故$dp[u]=B[u]$,也即$B[u]$为所求。
现在可以求出$E(min(T))$,但我们需要求出$sum_{Tsubseteq S}(-1)^{|T|+1}*E(min(T))$,可以发现其实就是求$S$的子集权值和, 可以用$FWT(or)$预处理出所有$S$的答案,每次询问可以做到$O(1)$回答。
因$DP$过程中需要求逆元,故时间复杂度为$DP$的时间复杂度:$O(n2^n log mod)$
代码:
#include<cstdio> #include<iostream> #include<algorithm> #include<cstring> #include<vector> using namespace std; typedef long long ll; const int maxn = (1 << 18) + 5, mod = 998244353; ll qpow(ll x, ll y) { ll ret = 1; while(y) { if(y&1) ret = ret * x % mod; x = x * x % mod; y >>= 1; } return ret; } ll add(ll x, ll y) { return x + y < mod? x + y: x + y - mod; } ll rdc(ll x, ll y) { return x - y < 0? x - y + mod: x - y; } int n, m, Q, rt, deg[20], num[maxn]; ll f[maxn], A[20], B[20]; vector<int> G[maxn]; void dfs(int x, int fa, int s) { if((s >> (x - 1)) & 1) { A[x] = B[x] = 0; return ; } ll s1 = 0, s2 = 0; for(auto &id: G[x]) { if(id == fa) continue; dfs(id, x, s); s1 = add(s1, A[id]); s2 = add(s2, B[id]); } A[x] = qpow(rdc(deg[x], s1), mod - 2); B[x] = add(s2, deg[x]) * A[x] % mod; } void FWT(ll *x) { for(int i = 1; i <= m; i <<= 1) for(int j = 0; j <= m; j += (i << 1)) for(int k = 0; k < i; ++k) x[i+j+k] = add(x[i+j+k], x[j+k]); } int main() { scanf("%d%d%d", &n, &Q, &rt); int u, v, cnt, sta; for(int i = 1; i < n; ++i) { scanf("%d%d", &u, &v); G[u].push_back(v); G[v].push_back(u); ++ deg[u]; ++ deg[v]; } m = (1 << n) - 1; for(int i = 1; i <= m; ++i) { dfs(rt, 0, i); num[i] = num[i>>1] + (i & 1); f[i] = ((num[i] & 1)? 1: mod - 1) * B[rt] % mod; } FWT(f); for(int i = 1; i <= Q; ++i) { scanf("%d", &cnt); sta = 0; for(int j = 1; j <= cnt; ++j) { scanf("%d", &u); sta |= (1 << (u - 1)); } printf("%lld ", f[sta]); } return 0; }