code:
#include <bits/stdc++.h> #define N 200009 #define ll long long #define setIO(s) freopen(s".in","r",stdin) using namespace std; ll Sum[N]; int n,edges,root,sn; int val[N],hd[N],to[N<<1],nex[N<<1],size[N],mx[N],vis[N],A[N]; inline void add(int u,int v) { nex[++edges]=hd[u],hd[u]=edges,to[edges]=v; } void getroot(int u,int ff) { size[u]=1,mx[u]=0; for(int i=hd[u];i;i=nex[i]) { int v=to[i]; if(v==ff||vis[v]) continue; getroot(v,u); size[u]+=size[v]; mx[u]=max(mx[u],size[v]); } mx[u]=max(mx[u],sn-size[u]); if(mx[u]<mx[root]) root=u; } int ou; ll tmp,tot,bu[N]; map<int,ll>cn[N]; map<int,ll>::iterator it; int dep[N],cnt[N],siz[N]; void getnode(int top,int u,int ff,int cur) { if(!cnt[val[u]]) ++cur; ++cnt[val[u]]; Sum[top]+=(ll)cur; siz[u]=1; for(int i=hd[u];i;i=nex[i]) { int v=to[i]; if(v==ff||vis[v]) continue; getnode(top,v,u,cur); siz[u]+=siz[v]; } --cnt[val[u]]; } void get_col(int top,int u,int ff) { if(!cnt[val[u]]) cn[top][val[u]]+=(ll)siz[u]; ++cnt[val[u]]; for(int i=hd[u];i;i=nex[i]) { int v=to[i]; if(v==ff||vis[v]) continue; get_col(top,v,u); } --cnt[val[u]]; } void calc_v(int u,int ff) { ll tt=bu[val[u]]; tmp=tmp-bu[val[u]]+ou; bu[val[u]]=ou; Sum[u]+=tmp; for(int i=hd[u];i;i=nex[i]) { int v=to[i]; if(vis[v]||v==ff) continue; calc_v(v,u); } tmp=tmp-bu[val[u]]+tt; bu[val[u]]=tt; } void clr(int u,int ff) { cn[u].clear(); bu[val[u]]=0; for(int i=hd[u];i;i=nex[i]) { int v=to[i]; if(v==ff||vis[v]) continue; clr(v,u); } } void calc(int u) { tot=0; getnode(u,u,0,0); for(int i=hd[u];i;i=nex[i]) { int v=to[i]; if(vis[v]) continue; // memset(cnt,0,sizeof(cnt)); get_col(v,v,u); for(it=cn[v].begin();it!=cn[v].end();it++) { tot+=it->second; bu[it->first]+=it->second; } } for(int i=hd[u];i;i=nex[i]) { int v=to[i]; if(vis[v]) continue; tmp=tot; ou=siz[u]-siz[v]; for(it=cn[v].begin();it!=cn[v].end();it++) { bu[it->first]-=it->second; tmp-=it->second; } ll tt=bu[val[u]]; tmp=tmp-bu[val[u]]+ou; bu[val[u]]=ou; calc_v(v,u); bu[val[u]]=tt; for(it=cn[v].begin();it!=cn[v].end();it++) bu[it->first]+=it->second; } clr(u,0); } void dfs(int u) { calc(u); vis[u]=1; for(int i=hd[u];i;i=nex[i]) { int v=to[i]; if(vis[v]) continue; root=0,sn=size[v],getroot(v,u),dfs(root); } } int main() { // setIO("input"); int i,j; scanf("%d",&n); for(i=1;i<=n;++i) scanf("%d",&val[i]), A[i]=val[i]; sort(A+1,A+1+n); for(i=1;i<=n;++i) val[i]=lower_bound(A+1,A+1+n,val[i])-A; for(i=1;i<n;++i) { int u,v; scanf("%d%d",&u,&v),add(u,v),add(v,u); } sn=mx[0]=n,root=0,getroot(1,0),dfs(root); for(i=1;i<=n;++i) printf("%lld ",Sum[i]); return 0; }