Loj #3124. 「CTS2019 | CTSC2019」氪金手游
题目描述
小刘同学是一个喜欢氪金手游的男孩子。
他最近迷上了一个新游戏,游戏的内容就是不断地抽卡。现在已知:
- 卡池里总共有 (N) 种卡,第 (i) 种卡有一个权值 (W_i),小刘同学不知道 (W_i) 具体的值是什么。但是他通过和网友交流,他了解到 (W_i) 服从一个分布。
- 具体地,对每个 (i),小刘了解到三个参数 (p_{i,1},p_{i,2},p_{i,3}),(W_i) 将会以 (p_{i,j}) 的概率取值为 (j),保证 (p_{i,1}+p_{i,2}+p_{i,3}=1)。
小刘开始玩游戏了,他每次会氪一元钱来抽一张卡,其中抽到卡 (i) 的概率为:
小刘会不停地抽卡,直到他手里集齐了全部 (N) 种卡。
抽卡结束之后,服务器记录下来了小刘第一次得到每张卡的时间 (T_i)。游戏公司在这里设置了一个彩蛋:公司准备了 (N−1) 个二元组 ((u_i,v_i)),如果对任意的 (i),成立 (T_{u_i}<T_{v_i}),那么游戏公司就会认为小刘是极其幸运的,从而送给他一个橱柜的手办作为幸运大奖。
游戏公司为了降低获奖概率,它准备的这些 ((u_i,v_i)) 满足这样一个性质:对于任意的 (varnothing e Ssubsetneq{1,2,ldots,N}),总能找到 ((u_i,v_i)) 满足:(u_iin S,v_i otin S) 或者 (u_i otin S,v_iin S)。
请你求出小刘同学能够得到幸运大奖的概率,可以保证结果是一个有理数,请输出它对 (998244353) 取模的结果。
输入格式
第一行一个整数 (N),表示卡的种类数。
接下来 (N) 行,每行三个整数 (a_{i,1},a_{i,2},a_{i,3}),而题目给出的 (p_{i,j}=frac{a_{i,j}}{a_{i,1}+a_{i,2}+a_{i,3}})。
接下来 (N−1) 行,每行两个整数 (u_i,v_i),描述一个二元组(意义见题目描述)。
输出格式
输出一行一个整数,表示所求概率对 (998244353) 取模的结果。
数据范围与提示
对于全部的测试数据,保证 (Nle 1000),(a_{i,j}le 10^6)。
- (20) 分的数据,(Nle 15)。
- (15) 分的数据,(Nle 200),且每个限制保证 (|u_i−v_i|=1)。
- (20) 分的数据,(Nle 1000),且每个限制保证 (|u_i−v_i|=1)。
- (15) 分的数据,(Nle 200)。
- (30) 分的数据,无特殊限制。
首先给出的关系构成了一棵树。
随便选一个点作为根。假设所有点的(w)以及确定,先考虑如果所有边都是由父亲连向儿子的情况,那么答案就是
其中(size_i)表示(i)的子树中(w)之和。考虑第(i)个点要成为(i)的子树中第一个被抽到的概率,枚举抽到(i)子树外的卡片的次数得到:
从这个式子也可以看出每个点的概率是独立的。
所以我们就可以(DP),设(f_{i,j})表示以(i)为根的子树中,(size_i=j)的概率。
对于由儿子连向父亲的边我们用容斥原理来处理。我们可以枚举一些边反向,另一些边消失,算出概率,假设(k)条边反向,那么容斥系数就是((-1)^k)。(DP)的时候遇到这种边就讨论一下是消失还是反向,带上容斥系数就行了。
代码:
#include<bits/stdc++.h>
#define ll long long
#define N 1005
using namespace std;
inline int Get() {int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}while('0'<=ch&&ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}return x*f;}
const ll mod=998244353;
ll ksm(ll t,ll x) {
ll ans=1;
for(;x;x>>=1,t=t*t%mod)
if(x&1) ans=ans*t%mod;
return ans;
}
int n;
ll rate[N][4],a[N][4];
struct road {
int to,nxt;
int dir;
}s[N<<1];
int h[N],cnt;
void add(int i,int j,int d) {s[++cnt]=(road) {j,h[i],d};h[i]=cnt;}
int fa[N];
ll f[N][N*3];
ll tem[N*3];
ll inv[N*3];
int size[N];
void dfs(int v) {
f[v][0]=1;
for(int i=h[v];i;i=s[i].nxt) {
int to=s[i].to;
if(to==fa[v]) continue ;
fa[to]=v;
dfs(to);
for(int j=0;j<=(size[v]+size[to])*3;j++) tem[j]=0;
if(!s[i].dir) {
for(int j=0;j<=size[v]*3;j++) {
for(int k=0;k<=size[to]*3;k++) {
(tem[j+k]+=f[v][j]*f[to][k])%=mod;
}
}
} else {
ll sum=0;
for(int j=0;j<=size[to]*3;j++) (sum+=f[to][j])%=mod;
for(int j=0;j<=size[v]*3;j++) tem[j]=f[v][j]*sum%mod;
for(int j=0;j<=size[v]*3;j++) {
for(int k=0;k<=size[to]*3;k++) {
(tem[j+k]+=mod-f[v][j]*f[to][k]%mod)%=mod;
}
}
}
size[v]+=size[to];
for(int j=0;j<=size[v]*3;j++) f[v][j]=tem[j];
}
size[v]++;
for(int i=0;i<=size[v]*3;i++) tem[i]=0;
for(int i=0;i<=size[v]*3;i++) {
for(int j=1;j<=3;j++) {
(tem[i+j]+=f[v][i]*rate[v][j]%mod*j*inv[i+j])%=mod;
}
}
for(int i=0;i<=size[v]*3;i++) f[v][i]=tem[i];
}
int main() {
n=Get();
inv[0]=1;
for(int i=1;i<=n*3;i++) inv[i]=ksm(i,mod-2);
for(int i=1;i<=n;i++) {
for(int j=1;j<=3;j++) a[i][j]=Get();
ll sum=a[i][1]+a[i][2]+a[i][3];
for(int j=1;j<=3;j++) rate[i][j]=a[i][j]*ksm(sum,mod-2)%mod;
}
for(int i=1;i<n;i++) {
int a=Get(),b=Get();
add(a,b,0),add(b,a,1);
}
dfs(1);
ll ans=0;
for(int i=0;i<=size[1]*3;i++) (ans+=f[1][i])%=mod;
cout<<ans;
return 0;
}