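发现其实就是左右 2 棵子树。记 $f_{x,v}$ 为结点 $x$ 取到权值 $v$ 的概率,$p_x$ 为结点 $x$ 取儿子中较大值的概率。若 $v$ 来自左儿子 $l$(右儿子为 $r$),则左儿子选到该值的概率就是
$$f_{x,v}=f_{l,v}\left(p_x\sum_{u<v}f_{r,u}+(1-p_x)\sum_{u>v}f_{r,u}\right)$$
右儿子类似(交换 $l$ 与 $r$)。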
由于保证所有叶子权值互不相等,每个权值只会出现在一棵子树中,上式中的两个求和恰好是另一棵子树的前缀和与后缀和(两者相加为 $1$)。
然后用线段树合并:合并时先右后左(从大到小),顺带维护另一棵子树的后缀和,一次合并就完了。
复杂度 $O(n\log n)$。
#include<bits/stdc++.h>
using namespace std;
const int RLEN=1<<20|1;  // input buffer size: 2^20 + 1 bytes
// Buffered byte reader over stdin: refills ibuf with a single fread whenever the
// buffer is drained. Returns EOF (converted to char) once fread yields no bytes.
inline char gc(){
    static char ibuf[RLEN],*ib,*ob;  // [ib, ob) is the still-unread part of ibuf
    if(ib==ob)ob=(ib=ibuf)+fread(ibuf,1,RLEN,stdin);
    return (ib==ob)?EOF:*ib++;
}
// Reads one decimal integer from stdin via gc().
// Non-digit characters before the number are skipped; every '-' among them toggles
// the sign (so "--5" reads as 5). If the input ends before any digit is seen,
// returns 0 instead of spinning forever on EOF.
// Digits are tested with an explicit '0'..'9' range: isdigit() on a negative
// (non-ASCII) char is undefined behaviour. A raw 0xFF byte is indistinguishable
// from EOF once gc() squeezes EOF into a char, and is treated as end of input.
inline int read(){
    const char eof=static_cast<char>(EOF);
    char ch=gc();
    bool neg=false;
    while(!(ch>='0'&&ch<='9')){
        if(ch==eof)return 0;
        if(ch=='-')neg=!neg;
        ch=gc();
    }
    int res=0;
    while(ch>='0'&&ch<='9'){
        res=res*10+(ch-'0');
        ch=gc();
    }
    return neg?-res:res;
}
// All probabilities are stored as residues modulo this prime.
const int mod=998244353;
// Capacity of every per-node array.
const int N=500005;

// Returns (a + b) mod `mod` for a, b already reduced into [0, mod).
inline int add(int a,int b){
    int sum=a+b;
    if(sum>=mod)sum-=mod;
    return sum;
}
// In-place a = (a + b) mod `mod`.
inline void selfadd(int &a,int b){
    a+=b;
    if(a>=mod)a-=mod;
}
// Returns (a - b) mod `mod` for a, b in [0, mod).
inline int dec(int a,int b){
    int diff=a-b;
    if(diff<0)diff+=mod;
    return diff;
}
// Returns (a * b) mod `mod`; the product is formed in 64 bits, so it cannot overflow.
inline int mul(int a,int b){
    long long prod=1LL*a*b;
    return static_cast<int>(prod<mod?prod:prod%mod);
}
// In-place a = (a * b) mod `mod`.
inline void selfmul(int &a,int b){
    a=mul(a,b);
}
// Binary exponentiation: returns res * a^b mod `mod` (res defaults to 1), for b >= 0.
// Because `mod` is prime, ksm(x, mod-2) is the modular inverse of x.
inline int ksm(int a,int b,int res=1){
    while(b){
        if(b&1)res=mul(res,a);
        b>>=1;
        a=mul(a,a);
    }
    return res;
}
int cnt[N];     // cnt[u]: number of children of node u (0 means u is a leaf)
int son[N][2];  // son[u][0..cnt[u]-1]: children of node u, in input order
int rt[N];      // rt[u]: root index of the Seg tree describing u's subtree
int val[N];     // not referenced anywhere in the visible code
int pmx[N];     // pmx[u]: internal node's input value k times inverse(10000), mod `mod`
int fa[N];      // fa[i]: parent of node i as read (presumably 0 for root node 1)
int ori[N];     // ori[i]: weight of the i-th smallest leaf
int ans;        // accumulated answer, printed by main
int n;          // number of nodes
// A leaf record: its weight k and its node index p.
// Compared by weight only, so sorting q lists the leaves by ascending weight.
struct node{
    int k,p;
    bool operator<(const node &rhs)const{
        return k<rhs.k;
    }
};
node q[N];  // q[1..tot]: the leaves
int tot;    // number of leaves stored in q
// Segment trees over leaf ranks 1..tot (rank i = i-th smallest leaf weight, see main).
// A tree stores, for each rank, the probability (a residue mod `mod`) that the root
// of the corresponding input subtree takes that leaf's weight. Children's trees are
// combined with merge(); calc() folds the root's tree into the answer.
namespace Seg{
#define mid ((l+r)>>1)
// tot: nodes allocated so far (index 0 is the empty tree, tr[0] stays 0)
// cnt: number of ranks already emitted by calc()
// mxa, mxb: merge() accumulators, see merge()
int tot,cnt,mxa,mxb;
// Node pool: children, pending multiplicative tag, and sum of probabilities in the subtree.
int lc[N*22],rc[N*22],tag[N*22],tr[N*22];

// Creates a fresh root-to-leaf path for rank k inside [l, r] and stores its root in u;
// every node on the path gets sum 1 and the identity tag.
void update(int &u,int l,int r,int k){
    u=++tot,tag[u]=1,tr[u]=1;
    if(l==r)return;
    if(k<=mid)update(lc[u],l,mid,k);
    else update(rc[u],mid+1,r,k);
}
// Recomputes u's sum from its children (a missing child contributes tr[0] == 0).
inline void pushup(int u){
    tr[u]=add(tr[lc[u]],tr[rc[u]]);
}
// Multiplies the whole subtree of u by t: applied to tr[u] now, deferred via tag[u].
inline void pushnow(int u,int t){
    selfmul(tag[u],t);
    selfmul(tr[u],t);
}
// Hands u's pending tag down to its existing children.
inline void pushdown(int u){
    if(tag[u]==1)return;
    int &t=tag[u];
    if(lc[u])pushnow(lc[u],t);
    if(rc[u])pushnow(rc[u],t);
    t=1;
}
// Merges r1 (left child's tree) and r2 (right child's tree) of a node whose
// max-probability is g, reusing r1's nodes; returns the merged root.
// Ranges are visited right (larger ranks) before left, so on entry mxa / mxb hold
// the left / right tree's total probability over all ranks larger than the current
// range. A value v from one side then survives with probability
//   g * P(other < v) + (1 - g) * P(other > v) = g + s - 2*g*s,
// where s = P(other > v) is the other side's accumulator and P(other < v) = 1 - s.
// This relies on each child's distribution summing to 1 and on leaf ranks being
// distinct (both sides are never non-empty at the same leaf).
// Callers must reset mxa = mxb = 0 before the top-level call.
int merge(int r1,int r2,int g){
    if(!r1&&!r2)return 0;
    if(!r1){
        // Only the right tree has ranks here: record its mass first (read before
        // scaling), then scale it by the survival factor against the left tree.
        selfadd(mxb,tr[r2]);
        pushnow(r2,dec(add(mxa,g),mul(2,mul(mxa,g))));
        return r2;
    }
    if(!r2){
        // Symmetric case: only the left tree has ranks here.
        selfadd(mxa,tr[r1]);
        pushnow(r1,dec(add(mxb,g),mul(2,mul(mxb,g))));
        return r1;
    }
    // Both sides are non-empty: push tags down before recursing into children.
    pushdown(r1),pushdown(r2);
    // Order matters: the right halves (larger ranks) must be merged first so that
    // mxa/mxb are correct suffix sums when the left halves are processed.
    rc[r1]=merge(rc[r1],rc[r2],g);
    lc[r1]=merge(lc[r1],lc[r2],g);
    pushup(r1);
    return r1;
}
// Builds rt[] for every node of the input tree under u, children before parents.
// Leaves already own their trees (set by update); a one-child node shares its
// child's tree; a two-child node merges both. Done iteratively (reverse pre-order)
// rather than recursively so that deep, chain-shaped trees cannot overflow the
// call stack. Merges in disjoint subtrees touch disjoint nodes, so only the
// children-before-parent order matters.
void dfs(int u){
    std::vector<int> order;          // pre-order: every parent precedes its children
    std::vector<int> pending(1,u);
    while(!pending.empty()){
        int x=pending.back();
        pending.pop_back();
        order.push_back(x);
        // Only nodes with 1 or 2 children are descended into.
        if(::cnt[x]==1||::cnt[x]==2)
            for(int i=0;i<::cnt[x];i++)pending.push_back(son[x][i]);
    }
    for(int i=(int)order.size()-1;i>=0;i--){
        int x=order[i];
        if(::cnt[x]==1)rt[x]=rt[son[x][0]];
        else if(::cnt[x]==2){
            mxa=mxb=0;
            rt[x]=merge(rt[son[x][0]],rt[son[x][1]],pmx[x]);
        }
    }
}
// In-order walk over ranks in [l, r] with nonzero stored probability: the cnt-th
// such rank, with weight ori[l] and probability D = tr[u], adds cnt * ori[l] * D^2 to ans.
// NOTE(review): a rank whose probability happens to be 0 as a residue mod `mod` is
// skipped and does not advance cnt.
void calc(int u,int l,int r){
    if(!tr[u])return;
    pushdown(u);
    if(l==r){
        cnt++;
        selfadd(ans,mul(mul(mul(cnt,ori[l]),tr[u]),tr[u]));
        return;
    }
    calc(lc[u],l,mid),calc(rc[u],mid+1,r);
}
#undef mid
}
signed main(){
n=read();
for(int i=1;i<=n;i++){
fa[i]=read();
son[fa[i]][cnt[fa[i]]++]=i;
}
int inv=ksm(10000,mod-2);
for(int i=1;i<=n;i++){
int k=read();
if(cnt[i])pmx[i]=mul(k,inv);
else{q[++tot]=(node){k,i};}
}
sort(q+1,q+tot+1);
for(int i=1;i<=tot;i++){
ori[i]=q[i].k,Seg::update(rt[q[i].p],1,tot,i);
}
Seg::dfs(1);
Seg::calc(rt[1],1,tot);
cout<<ans;
}