题目:https://loj.ac/problem/3124
看了题解:https://www.cnblogs.com/Itst/p/10883880.html
先考虑外向树。
考虑分母是 ( sum w ) ,同样一个子树,其实不会因为子树外部分的 ( sum w ) 不同而对子树内的 DP 值有影响。
比如,在只考虑以子树内的 ( sum w ) 为分母的情况下做出了 “ cr 子树内部合法的方案数 f[cr] ”
设 ( W' = sumlimits_{i in tree_{cr}} w_i ) , ( W = sumlimits_{i=1}^{n} w_i )
考虑在以 W 为分母的情况下, cr 子树里的点是 cr 最先出现的概率 ( p = frac{w_{cr}}{W}sumlimits_{k=0}^{infty}(1-frac{W'}{W})^k = frac{w_{cr}}{W'} )
上面的式子就是表示在取到 cr 之前, cr 子树里的点都不能被取到。
注意到只考虑 cr 子树里的点的时候, cr 被第一个取到的概率就是 ( frac{w_{cr}}{W'} ) !所以可以认为一个子树在总体里的合法概率就是只有这个子树时的合法概率。
所以做 f[ cr ] 的时候可以直接用 f[ v ] 来乘啦!( v 是 cr 的一个孩子)
因为每个点 cr 往 DP 数组里贡献 ( frac{w_{cr}}{W'} ) ,与一个 W' 有关,所以把 “子树里的 w 和” 记在状态里。
那么 f[ cr ][ j ] 就是把孩子们的 f[ ][ ] 像背包一样合并起来,最后再加入自己的贡献,就是枚举 ( w_{cr} ) ,( f[cr][j]*p[cr][i]*frac{i}{j+i} -> f[cr][j+i] )
这样就是外向树的情况。
有反边,考虑容斥掉。暴力的做法是 2n 枚举哪些反边强制变成正边,那么其余的反边变成 “没有限制” ;如果有 cnt 条反边被强制变成正边,容斥系数就是 ( -1 )cnt 。
考虑不枚举,把系数放在 DP 的过程中。即:每次把一条反边强制变成正边,答案就要乘一个 -1 ;
在转移的时候,遇到反边,把 f[ v ][ * ] 乘上 -1 转移;另一种选择是把 ( sumlimits_j f[v][j] ) 直接乘到每个 f[cr][ ] 上,f[cr][ ] 的第二维不变,表示该边无限制(无限制的话,( frac{w_{cr}}{W'} ) 的 W' 就不应该包含 v 的 ( sum w ))
#include<cstdio> #include<cstring> #include<algorithm> #define ll long long using namespace std; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return fx?ret:-ret; } const int N=1005,M=N*3,mod=998244353; int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;} int pw(int x,int k) {int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;} int n,p[N][5],f[N][M],inv[M],siz[N],g[M]; int hd[N],xnt,to[N<<1],nxt[N<<1]; bool lx[N<<1]; void add(int x,int y,bool fx) {to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;lx[xnt]=fx;} void dfs(int cr,int fa) { f[cr][0]=1; siz[cr]=0; int tp; for(int i2=hd[cr],v;i2;i2=nxt[i2]) if((v=to[i2])!=fa) { dfs(v,cr); if(lx[i2])//lx[i2] not lx[cr]!!!!!! { for(int i=siz[cr];i>=0;i--) { tp=f[cr][i]; f[cr][i]=0; for(int j=1;j<=siz[v];j++)//j=1 is ok f[cr][i+j]=(f[cr][i+j]+(ll)tp*f[v][j])%mod; } } else { int sm=0;for(int i=0;i<=siz[v];i++)sm=upt(sm+f[v][i]); for(int i=siz[cr];i>=0;i--) { tp=f[cr][i]; f[cr][i]=0; for(int j=1;j<=siz[v];j++) f[cr][i+j]=upt(f[cr][i+j]-(ll)tp*f[v][j]%mod); f[cr][i]=(ll)tp*sm%mod; } } siz[cr]+=siz[v]; } for(int i=siz[cr];i>=0;i--) { tp=f[cr][i]; f[cr][i]=0; for(int w=1;w<=3;w++) f[cr][i+w]=(f[cr][i+w]+(ll)p[cr][w]*tp%mod*w%mod*inv[i+w])%mod; } siz[cr]+=3; } int main() { n=rdn(); for(int i=1,a,b,c,s;i<=n;i++) { a=rdn();b=rdn();c=rdn(); s=a+b+c; s=pw(s,mod-2); p[i][1]=(ll)a*s%mod; p[i][2]=(ll)b*s%mod; p[i][3]=(ll)c*s%mod; } for(int i=1,u,v;i<n;i++) u=rdn(),v=rdn(),add(u,v,1),add(v,u,0); inv[1]=1; for(int i=2;i<=n*3;i++) inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod; dfs(1,0); int ans=0; for(int i=0;i<=siz[1];i++)ans=upt(ans+f[1][i]); printf("%d ",ans); return 0; }