我们考虑计算红色点与非红色点的对数。
我们用f[i][j]表示i的子树中有j个红色点的概率,将i所有子树合并。
接着我们对于每一个状态,枚举i是红色还是非红色算概率。
同时我们可以求出i和i子树内一个是红色一个是非红色的期望对数。
同理我们计算出黑与非黑,白与非白。
由于红与非红=红黑+红白,黑与非黑=红黑+黑白,白与非白=红白+黑白,因此我们可以把红黑、红白和黑白算出来。
下面这份代码大部分都是模板,自行忽略...
#include <iostream> #include <stdio.h> #include <math.h> #include <string.h> #include <time.h> #include <stdlib.h> #include <string> #include <bitset> #include <vector> #include <set> #include <map> #include <queue> #include <algorithm> #include <sstream> #include <stack> #include <iomanip> using namespace std; #define pb push_back #define mp make_pair typedef pair<int,int> pii; typedef long long ll; typedef double ld; typedef vector<int> vi; #define fi first #define se second #define fe first #define FO(x) {freopen(#x".in","r",stdin);freopen(#x".out","w",stdout);} #define es(x,e) (int e=fst[x];e;e=nxt[e]) #define VIZ {printf("digraph G{ "); for(int i=1;i<=n;i++) for es(i,e) printf("%d->%d; ",i,vb[e]); puts("}");} #define SZ 66666 int n,m,fst[SZ],vb[SZ],nxt[SZ],M=0; ll MOD=998244353; void ad_de(int a,int b) { ++M; nxt[M]=fst[a]; fst[a]=M; vb[M]=b; } void adde(int a,int b) { ad_de(a,b); ad_de(b,a); } #define S 20 int dep[SZ],fa[SZ],p[SZ][S],sz[SZ]; void dfs(int x,int f=0) { sz[x]=1; for es(x,e) { int b=vb[e]; if(b==f) continue; fa[b]=p[b][0]=x; dep[b]=dep[x]+1; dfs(b,x); sz[x]+=sz[b]; } } void pre() { dfs(1); for(int i=1;i<S;i++) { for(int j=1;j<=n;j++) p[j][i]=p[p[j][i-1]][i-1]; } } int jump(int x,int d) { for(int s=S-1;s>=0;s--) { if(p[x][s]&&dep[p[x][s]]>=d) x=p[x][s]; } if(dep[x]!=d) exit(-1); return x; } int lca(int a,int b) { if(dep[a]>dep[b]) swap(a,b); b=jump(b,dep[a]); if(a==b) return a; for(int s=S-1;s>=0;s--) { if(p[a][s]!=p[b][s]) a=p[a][s], b=p[b][s]; } return p[a][0]; } int dis(int a,int b) { return dep[a]+dep[b]-dep[lca(a,b)]*2; } #define gc getchar() int gint() { int s=0,t=0; while(s=gc,s<'0'||s>'9');t=s-48; while(s=gc,s>='0'&&s<='9') t=t*10+s-48; return t; } #define gi gint() ll qp(ll a,ll b) { a%=MOD; ll ans=1; while(b) { if(b&1) ans=ans*a%MOD; a=a*a%MOD; b>>=1; } return ans; } ll gx[23],rp[SZ][3]; ll gl[1003][1003][3]; ll tmp[1003]; bool lv[1003]; ll tot[3]; void dfs2(int x,int f=0) { if(lv[x]) { for(int j=0;j<3;j++) gl[x][0][j]=(1-rp[x][j])%MOD, gl[x][1][j]=rp[x][j]; return; } int cs=0; gl[x][0][0]=gl[x][0][1]=gl[x][0][2]=1; for es(x,e) { int b=vb[e]; if(b==f) continue; dfs2(b,x); int cc=cs+sz[b]; for(int k=0;k<3;k++) { for(int i=0;i<=cc;i++) tmp[i]=0; for(int i=0;i<=cs;i++) { for(int j=0;j<=sz[b];j++) { tmp[i+j]=(tmp[i+j]+gl[x][i][k]*gl[b][j][k]%MOD)%MOD; } } for(int i=0;i<=cc;i++) gl[x][i][k]=tmp[i]; } cs=cc; } for(int k=0;k<3;k++) { for(int i=0;i<=cs+1;i++) tmp[i]=0; for(int i=0;i<=cs;i++) { ll y=i*qp(cs,MOD-2)%MOD; ll g1=y*gl[x][i][k]%MOD; tot[k]=(tot[k]+g1*(cs-i)%MOD)%MOD; ll g2=(1-y)*gl[x][i][k]%MOD; tot[k]=(tot[k]+g2*i%MOD)%MOD; tmp[i+1]=(tmp[i+1]+g1)%MOD; tmp[i]=(tmp[i]+g2)%MOD; } for(int i=0;i<=cs+1;i++) gl[x][i][k]=tmp[i]; } } int main() { FO(tree3) n=gi, gx[01]=gi, gx[02]=gi, gx[12]=gi; for(int i=1;i<n;i++) { int x=gi, y=gi; adde(x,y); } for(int i=1;i<=n;i++) { for(int j=0;j<3;j++) rp[i][j]=gi; } pre(); for(int x=1;x<=n;x++) { bool leaf=1; for es(x,e) leaf&=(vb[e]==fa[x]); lv[x]=leaf; } dfs2(1); ll aa=(tot[0]+tot[1]+tot[2])%MOD*qp(2,MOD-2)%MOD,ans=0; ans=ans+gx[01]*(aa-tot[2])%MOD; ans=ans+gx[02]*(aa-tot[1])%MOD; ans=ans+gx[12]*(aa-tot[0])%MOD; ans=(ans%MOD+MOD)%MOD; cout<<ans<<" "; }