题意
n种卡,卡$i$有$p_{i,1},p_{i,2},p_{i,3}$的概率取权值$W_i=1,2,3$,抽一次卡获得卡$i$的概率是$frac{W_i}{sum_j W_j}$。
问:经过若干次抽卡并集齐所有卡后,第一次抽卡到每张卡的时间$T_i$满足给定的树形关系的概率(对于有向边$(u,v)$,满足$T_u < T_v$)。
考场上万万没想到容斥......QwQ
首先考虑一条链的情况$$1 ightarrow 2 ightarrow 3 ightarrow …… ightarrow n$$
$$P(1 ightarrow n)=prod_{i=1}^{n} frac{W_i}{sum_{j=i}^{n} W_j}$$
链上只有一条反向边时:
update:感谢 @JSOI爆零珂学家yzhang 指正:
由容斥原理:$$P=P(1 ightarrow k) imes P( k+1 ightarrow n) - P(1 ightarrow n)$$
将单向链推广到外向树,对于每个节点,由于是求第一次抽卡时间,各子树的抽卡时间只要满足子树内的先后关系,所以各子树之间概率是独立的,可以用dp转移。设$f[i][j]$表示$i$的子树权值和为$j$时状态合法的概率,像背包一样合并转移,复杂度$O(n^2)$。
考虑加入反向边,用容斥计算:$P(此边是反向边)=P(此边不存在)-P(此边是正向边)$
并不想优化常数
1 #include <cstdio>
2 #include <cstring>
3 #include <algorithm>
4 using namespace std;
5 typedef long long ll;
6 const int N=1010;
7 const ll P=998244353;
8 int n,p[N][4],x,y,inv[N*3],f[N][N*3],g[N*3],ans,size[N];
9 int head[N],to[N<<1],nxt[N<<1],st[N<<1],q;
10 inline int read() {
11 int re=0; char ch=getchar();
12 while (ch<'0'||ch>'9') ch=getchar();
13 while (ch>='0'&&ch<='9') re=re*10+ch-48,ch=getchar();
14 return re;
15 }
16 inline void add(int x,int y) {
17 to[++q]=y; nxt[q]=head[x]; head[x]=q; st[q]=1;
18 to[++q]=x; nxt[q]=head[y]; head[y]=q; st[q]=P-1;
19 }
20 ll qpow(ll x,ll k) {ll r=1; for (; k; k>>=1,x=x*x%P) if (k&1) r=r*x%P; return r;}
21 void dfs(int x,int fa) {
22 f[x][0]=1; int y;
23 for (int i=head[x]; i; i=nxt[i]) if (to[i]!=fa) {
24 dfs(to[i],x); y=to[i];
25 memset(g,0,sizeof g);
26 for (int j=0; j<=size[x]; j++) for (int k=0; k<=size[y]; k++) {
27 g[j+k]=(g[j+k] + 1ll * f[x][j] * f[y][k] %P * st[i] %P) %P;
28 if (st[i]!=1) g[j]=(g[j] + 1ll * f[x][j] * f[y][k] %P) %P;
29 }
30 memcpy(f[x],g,sizeof g); size[x]+=size[y];
31 }
32 memset(g,0,sizeof g);
33 for (int j=0; j<=size[x]; j++) for (int k=1; k<=3; k++) {
34 g[j+k]=(g[j+k] + 1ll * f[x][j] * k %P * inv[j+k] %P * p[x][k] %P) %P;
35 }
36 memcpy(f[x],g,sizeof g); size[x]+=3;
37 }
38 int main() {
39 n=read();
40 for (int i=1; i<=n; i++) {
41 x=0; for (int j=1; j<=3; j++) x+=(p[i][j]=read());
42 for (int j=1; j<=3; j++) p[i][j]=p[i][j]*qpow(x,P-2)%P;
43 }
44 for (int i=1; i<n; i++) x=read(),y=read(),add(x,y);
45 inv[1]=1; for (int i=2; i<=n*3; i++) inv[i]=1ll*(P-P/i)*inv[P%i]%P;
46 dfs(1,0);
47 for (int i=1; i<=3*n; i++) ans=(ans+f[1][i])%P;
48 printf("%d
",ans);
49 return 0;
50 }