题目
题目链接:https://www.luogu.com.cn/problem/P5405
小刘同学是一个喜欢氪金手游的男孩子。
他最近迷上了一个新游戏,游戏的内容就是不断地抽卡。现在已知:
- 卡池里总共有 (N) 种卡,第 (i) 种卡有一个权值 (W_i),小刘同学不知道 (W_i) 具体的值是什么。但是他通过和网友交流,他了解到 (W_i) 服从一个分布。
- 具体地,对每个 (i),小刘了解到三个参数 (p_{i,1},p_{i,2},p_{i,3}),(W_i) 将会以 (p_{i,j}) 的概率取值为 (j),保证 (p_{i,1}+p_{i,2}+p_{i,3}=1)。
小刘开始玩游戏了,他每次会氪一元钱来抽一张卡,其中抽到卡 (i) 的概率为:
小刘会不停地抽卡,直到他手里集齐了全部 (N) 种卡。
抽卡结束之后,服务器记录下来了小刘第一次得到每张卡的时间 (T_i)。游戏公司在这里设置了一个彩蛋:公司准备了 (N-1) 个二元组 ((u_i,v_i)),如果对任意的 (i),成立 (T_{u_i}<T_{v_i}),那么游戏公司就会认为小刘是极其幸运的,从而送给他一个橱柜的手办作为幸运大奖。
游戏公司为了降低获奖概率,它准备的这些 ((u_i,v_i)) 满足这样一个性质:对于任意的 (varnothing
e Ssubsetneq{1,2,ldots,N}),总能找到 ((u_i,v_i)) 满足:(u_iin S,v_i
otin S) 或者 (u_i
otin S,v_iin S)。
请你求出小刘同学能够得到幸运大奖的概率,可以保证结果是一个有理数,请输出它对 (998244353) 取模的结果。
思路
题目中给出了关于所有二元组的性质。如果我们把二元组 ((u_i,v_i)) 看作一条从 (u_i) 向 (v_i) 的有向边,那么其实等价于连边后会形成一棵树(无视边的方向)。
先考虑如果这棵树是一棵外向树怎么办。不妨设 (1) 为根,对于一个点 (x),它被抽到的时间必须小于它子树内的点被抽到的时间,它子树外的点与它无关。记 (sum[x]) 表示 (x) 子树内所有点的 (w) 之和,这种情况的概率为
后面那一坨东西等比数列求和一下,发现原式等于 (frac{w_x}{sum[x]})。
设 (f[x][i]) 表示点 (x) 的子树内,所有点的 (w) 之和为 (i) 的情况下,满足子树内所有条件的期望。
那么合并 (x) 和它的一棵子树 (y) 时只需要树上背包枚举到子树大小就可以做到 (O(n^2)) 了。
当这棵树不是一棵外向树时,考虑把反向边的贡献容斥掉。
那么合并 (x) 的一个子树 (y) 时,如果 (x) 与 (y) 的连边是 (y o x) 的,如果算 (y) 子树的贡献,那么就需要加上 (y) 子树内 (w) 之和,容斥系数乘上 (-1);如果不算 (y) 子树的贡献,那么就不加 (y) 子树内 (w) 之和,容斥系数也不用乘。
预处理逆元可以做到 (O(n^2))。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1010,MOD=998244353;
int n,tot,a[N][4],head[N],siz[N];
ll ans,inv[N],f[N][N*3],g[N*3];
struct edge
{
int next,to;
}e[N*2];
void add(int from,int to)
{
e[++tot]=(edge){head[from],to};
head[from]=tot;
}
ll fpow(ll x,ll k)
{
ll ans=1;
for (;k;k>>=1,x=x*x%MOD)
if (k&1) ans=ans*x%MOD;
return ans;
}
void dfs(int x,int fa)
{
ll res=fpow(a[x][1]+a[x][2]+a[x][3],MOD-2);
f[x][1]=1LL*a[x][1]*res%MOD;
f[x][2]=2LL*a[x][2]*res%MOD;
f[x][3]=3LL*a[x][3]*res%MOD;
siz[x]=1;
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=fa)
{
dfs(v,x);
for (int j=1;j<=siz[x]*3;j++)
g[j]=f[x][j],f[x][j]=0;
for (int j=1;j<=siz[x]*3;j++)
for (int k=1;k<=siz[v]*3;k++)
{
f[x][j+k]=(f[x][j+k]+((i&1)?1LL:-1LL)*g[j]*f[v][k])%MOD;
if (!(i&1)) f[x][j]=(f[x][j]+1LL*g[j]*f[v][k])%MOD;
}
siz[x]+=siz[v];
}
}
for (int i=1;i<=siz[x]*3;i++)
f[x][i]=f[x][i]*inv[i]%MOD;
}
int main()
{
memset(head,-1,sizeof(head));
scanf("%d",&n);
for (int i=1;i<=n;i++)
scanf("%d%d%d",&a[i][1],&a[i][2],&a[i][3]);
for (int i=1;i<=n*3;i++) inv[i]=fpow(i,MOD-2);
for (int i=1,x,y;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
dfs(1,0);
for (int i=1;i<=n*3;i++)
ans=(ans+f[1][i])%MOD;
cout<<(ans+MOD)%MOD;
return 0;
}