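离散化后,容易想到设 $f_{i,j}$ 为节点 $i$ 的权值为 $j$ 的概率。记 $i$ 的左、右儿子为 $l, r$,不妨设权值 $j$ 来自左子树。由于叶子权值互不相同($j$ 在左子树时 $f_{r,j}=0$),有 $f_{i,j}=f_{l,j}\left(p_i\sum_{k<j}f_{r,k}+(1-p_i)\sum_{k>j}f_{r,k}\right)$;$j$ 来自右子树时同理(交换 $l, r$)。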
考虑用线段树合并优化这个 dp。合并时同时维护两棵线段树在当前区间左侧的前缀和;递归到某个区间时,若其中一棵线段树在该区间为空,则给另一棵线段树的对应节点打上乘法标记即可。注意前缀和必须使用合并前的值,不要用合并后的值。
// BZOJ 5461 / PKUWC2018 "Minimax".
// f[i][j] = probability that node i ends up with the j-th smallest leaf
// value. Each node's distribution is stored as a dynamic segment tree over
// value ranks, and a parent's distribution is obtained by merging its two
// children's trees (segment-tree merging) with multiplicative lazy tags.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <utility>

namespace {

constexpr int N = 300010;             // capacity for tree nodes / leaves
constexpr int P = 998244353;          // modulus
constexpr int kInv10000 = 796898467;  // 10000^{-1} mod P

// Reads one (optionally negative) decimal integer from stdin.
// Stops at EOF instead of spinning forever on truncated input.
int read() {
    int value = 0;
    int sign = 1;
    int c = std::getchar();
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') sign = -1;
        c = std::getchar();
    }
    while (c >= '0' && c <= '9') {
        value = value * 10 + (c - '0');
        c = std::getchar();
    }
    return value * sign;
}

int n, m;             // number of tree nodes, number of leaves (distinct values)
int head[N];          // head[v] = first edge index of v's child list (0 = none)
int val[N];           // leaf: value rank after discretization; internal: p_i mod P
int sorted_vals[N];   // sorted leaf values, 1-indexed
int root[N];          // root of node v's segment tree
int edge_cnt;

struct Edge {
    int to, nxt;
} edge[N];

// Appends `child` to the front of `parent`'s child list.
void add_edge(int parent, int child) {
    edge[++edge_cnt] = {child, head[parent]};
    head[parent] = edge_cnt;
}

// Dynamic segment tree node. `sum` is the total probability in the node's
// range; `lazy` is a pending multiplier for the children, with 1 meaning
// "no pending tag". Using 1 rather than 0 as the sentinel keeps a genuine
// multiplier that happens to be 0 mod P from being silently dropped.
struct SegNode {
    int l, r, sum, lazy;
} tree[N << 5];
int node_cnt;

// Sets probability 1 at value rank `pos` in the (initially empty) tree `k`.
void ins(int &k, int l, int r, int pos) {
    if (!k) {
        k = ++node_cnt;
        tree[k].lazy = 1;
    }
    tree[k].sum = 1;
    if (l == r) return;
    const int mid = (l + r) >> 1;
    if (pos <= mid)
        ins(tree[k].l, l, mid, pos);
    else
        ins(tree[k].r, mid + 1, r, pos);
}

// Multiplies the whole subtree rooted at `k` by `factor` (lazily). No-op on 0.
void apply_mul(int k, int factor) {
    if (!k) return;
    tree[k].sum = static_cast<int>(1LL * tree[k].sum * factor % P);
    tree[k].lazy = static_cast<int>(1LL * tree[k].lazy * factor % P);
}

// Pushes a pending multiplier of node `k` down to its children.
void push_down(int k) {
    if (!k || tree[k].lazy == 1) return;
    apply_mul(tree[k].l, tree[k].lazy);
    apply_mul(tree[k].r, tree[k].lazy);
    tree[k].lazy = 1;
}

// Returns the probability stored at value rank `pos` in tree `k` (0 if absent).
int query(int k, int l, int r, int pos) {
    if (!k) return 0;
    if (l == r) return tree[k].sum;
    push_down(k);
    const int mid = (l + r) >> 1;
    return pos <= mid ? query(tree[k].l, l, mid, pos)
                      : query(tree[k].r, mid + 1, r, pos);
}

// Merges the children's trees `x` and `y` over the range [l, r] into the
// parent's tree (destructively reusing `x`, or `y` if `x` is empty).
//   sx / sy : total probability of x / y over values strictly below l,
//             taken from the trees *before* merging.
//   prob    : probability that the parent takes the maximum of its children.
int merge(int x, int y, int l, int r, int sx, int sy, int prob) {
    if (!x || !y) {
        if (!x) {
            x = y;
            std::swap(sx, sy);
        }
        // Only `x` has mass in [l, r]. The other child has mass sy below l and,
        // since its total is 1, mass 1 - sy above r. A value j in x survives
        // as the max with the other child below it, or as the min with it above.
        const int factor = static_cast<int>(
            (1LL * prob * sy + 1LL * (P + 1 - prob) * (P + 1 - sy)) % P);
        apply_mul(x, factor);
        return x;
    }
    push_down(x);
    push_down(y);
    if (l < r) {
        const int mid = (l + r) >> 1;
        // Capture the pre-merge left-half sums before either half is merged.
        const int sx_right = (sx + tree[tree[x].l].sum) % P;
        const int sy_right = (sy + tree[tree[y].l].sum) % P;
        tree[x].r = merge(tree[x].r, tree[y].r, mid + 1, r, sx_right, sy_right, prob);
        tree[x].l = merge(tree[x].l, tree[y].l, l, mid, sx, sy, prob);
        tree[x].sum = (tree[tree[x].l].sum + tree[tree[x].r].sum) % P;
    }
    return x;
}

// Computes root[v] for every node reachable from node 1, children before
// parents. Iterative (BFS order processed in reverse) so that a deep tree,
// e.g. a long chain of single-child nodes, cannot overflow the call stack.
void solve_all_nodes() {
    static int order[N];
    int len = 0;
    order[len++] = 1;
    for (int h = 0; h < len; ++h)
        for (int e = head[order[h]]; e; e = edge[e].nxt)
            order[len++] = edge[e].to;

    for (int h = len - 1; h >= 0; --h) {
        const int k = order[h];
        const int first = head[k];
        if (!first) {
            ins(root[k], 1, m, val[k]);                       // leaf
        } else if (!edge[first].nxt) {
            root[k] = root[edge[first].to];                   // one child: copy
        } else {
            root[k] = merge(root[edge[first].to],
                            root[edge[edge[first].nxt].to],
                            1, m, 0, 0, val[k]);              // two children
        }
    }
}

}  // namespace

int main() {
#ifndef ONLINE_JUDGE
    (void)std::freopen("bzoj5461.in", "r", stdin);
    (void)std::freopen("bzoj5461.out", "w", stdout);
#endif
    n = read();
    for (int i = 1; i <= n; ++i) {
        // Parent of node i; node 1 is treated as the root, so its parent entry
        // only creates an edge from a node that is never traversed.
        const int parent = read();
        add_edge(parent, i);
    }
    for (int i = 1; i <= n; ++i) {
        val[i] = read();
        if (head[i]) {
            // Internal node: probability given as an integer scaled by 10000.
            val[i] = static_cast<int>(1LL * val[i] * kInv10000 % P);
        } else {
            sorted_vals[++m] = val[i];
        }
    }
    std::sort(sorted_vals + 1, sorted_vals + m + 1);
    for (int i = 1; i <= n; ++i)
        if (!head[i])
            val[i] = static_cast<int>(
                std::lower_bound(sorted_vals + 1, sorted_vals + m + 1, val[i]) - sorted_vals);

    solve_all_nodes();

    // Answer = sum over ranks i of i * V_i * D_i^2 (mod P).
    long long ans = 0;
    for (int i = 1; i <= m; ++i) {
        const int d = query(root[1], 1, m, i);
        ans = (ans + 1LL * i * sorted_vals[i] % P * d % P * d) % P;
    }
    std::cout << ans << '\n';
    return 0;
}