// Problem: https://www.luogu.org/problemnew/show/P2664
// For every vertex i of a tree whose vertices carry colours, prints
// sum over all vertices j of s(i, j), where s(i, j) is the number of distinct
// colours on the path between i and j (the path i -> i has one colour).
//
// Idea, O(n log n) for colour compression + O(n) for the two passes:
//   cnt[c]  = number of vertices j whose path from the current root contains
//             colour c; the answer for the current root is sum_c cnt[c].
//   keep[v] = number of vertices j in subtree(v) whose path from v contains
//             the colour of v's parent.
// Pass 1 (rooted at vertex 1) computes sz[], keep[] and cnt[] for root 1.
// Pass 2 moves the root from u to each child v. Only two entries change:
// cnt[colour(v)] becomes n, and cnt[colour(u)] loses the sz[v] - keep[v]
// vertices of subtree(v) that no longer meet colour(u). The change is undone
// when the walk returns to u.
// Both passes use an explicit stack, so deep trees (e.g. a path of 1e5
// vertices) cannot overflow the call stack.
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>

typedef long long ll;

namespace {

int n;                      // number of vertices, numbered 1..n
std::vector<int> col;       // compressed colours col[1..n] in 1..k; col[0] = 0 is the root's virtual parent
std::vector<int> head;      // head[u]: first edge id of u's adjacency list (0 terminates a list)
std::vector<int> to, nxt;   // edge id -> other endpoint / next edge id in the same list
std::vector<int> sz;        // subtree sizes for the tree rooted at vertex 1
std::vector<ll> cnt;        // cnt[c] as defined above, for the current root
std::vector<ll> keep;       // keep[v] as defined above
std::vector<ll> ans;        // ans[u]: final answer for vertex u
ll s;                       // sum of cnt[] for the current root

// Stack frame of pass 1: vertex, parent, next edge to scan, and the values of
// cnt[col[u]] (t) and cnt[col[fa]] (z) observed when u was entered.
struct SizeFrame {
    int u, fa, k;
    ll t, z;
};

// Stack frame of pass 2: vertex, parent, next edge to scan, and the value of
// cnt[col[u]] before the root was moved onto u (unused for vertex 1).
struct RerootFrame {
    int u, fa, k;
    ll saved;
};

// Enters u during pass 1: resets its size and snapshots the two counters.
void pushSizeFrame(std::vector<SizeFrame> &st, int u, int fa) {
    sz[u] = 1;
    SizeFrame f;
    f.u = u;
    f.fa = fa;
    f.k = head[u];
    f.t = cnt[col[u]];
    f.z = cnt[col[fa]];
    st.push_back(f);
}

// Enters u during pass 2, remembering what to restore cnt[col[u]] to.
void pushRerootFrame(std::vector<RerootFrame> &st, int u, int fa, ll saved) {
    RerootFrame f;
    f.u = u;
    f.fa = fa;
    f.k = head[u];
    f.saved = saved;
    st.push_back(f);
}

// Pass 1: fills sz[], keep[] and cnt[] as seen from `root`.
// When u is finished, cnt[col[u]] is reset to (its value on entry) + sz[u].
// Every vertex of subtree(u) has col[u] on its path from the root, so this
// replaces whatever same-coloured vertices inside subtree(u) added.
// keep[u] is how much cnt[col[parent]] grew while subtree(u) was processed.
void dfs1(int root) {
    std::vector<SizeFrame> st;
    st.reserve(n);  // depth <= n, so references into st stay valid
    pushSizeFrame(st, root, 0);
    while (!st.empty()) {
        SizeFrame &top = st.back();
        while (top.k != 0 && to[top.k] == top.fa) top.k = nxt[top.k];  // skip the edge back up
        if (top.k != 0) {
            const int u = top.u;
            const int v = to[top.k];
            top.k = nxt[top.k];
            pushSizeFrame(st, v, u);
            continue;
        }
        const SizeFrame done = top;
        st.pop_back();
        cnt[col[done.u]] = done.t + sz[done.u];
        keep[done.u] = cnt[col[done.fa]] - done.z;
        if (!st.empty()) sz[st.back().u] += sz[done.u];
    }
}

// Pass 2: walks the tree from `root`, keeping cnt[] and s valid for the
// vertex on top of the stack, and records ans[u] = s when u is entered.
// Expects s == sum(cnt) for `root` on entry.
void dfs2(int root) {
    std::vector<RerootFrame> st;
    st.reserve(n);
    ans[root] = s;
    pushRerootFrame(st, root, 0, 0);
    while (!st.empty()) {
        RerootFrame &top = st.back();
        while (top.k != 0 && to[top.k] == top.fa) top.k = nxt[top.k];  // skip the edge back up
        if (top.k != 0) {
            const int u = top.u;
            const int v = to[top.k];
            top.k = nxt[top.k];
            // Seen from v, every path starts at v, so all n vertices meet col[v].
            const ll saved = cnt[col[v]];
            s += n - saved;
            cnt[col[v]] = n;
            // Vertices of subtree(v) whose path from v lacks col[u] stop meeting it
            // (delta = -(sz[v] - keep[v]); zero when col[u] == col[v]).
            const ll delta = keep[v] - sz[v];
            s += delta;
            cnt[col[u]] += delta;
            ans[v] = s;
            pushRerootFrame(st, v, u, saved);
            continue;
        }
        const RerootFrame done = top;
        st.pop_back();
        if (st.empty()) break;  // finished the root itself: nothing to undo
        // Move the root back from done.u to its parent done.fa (reverse order).
        s += done.saved - cnt[col[done.u]];
        cnt[col[done.u]] = done.saved;
        const ll delta = keep[done.u] - sz[done.u];
        s -= delta;
        cnt[col[done.fa]] -= delta;
    }
}

}  // namespace

int main() {
    if (std::scanf("%d", &n) != 1 || n <= 0) return 0;

    col.assign(n + 1, 0);
    for (int i = 1; i <= n; ++i) std::scanf("%d", &col[i]);

    // Compress colours to 1..k so any int colour works and cnt[] stays O(n).
    // Index 0 is reserved for the root's virtual parent and is never a real colour.
    std::vector<int> palette(col.begin() + 1, col.end());
    std::sort(palette.begin(), palette.end());
    palette.erase(std::unique(palette.begin(), palette.end()), palette.end());
    for (int i = 1; i <= n; ++i)
        col[i] = int(std::lower_bound(palette.begin(), palette.end(), col[i]) - palette.begin()) + 1;

    // Each undirected edge is stored twice; edge ids are 1..2(n-1).
    head.assign(n + 1, 0);
    to.assign(2 * n - 1, 0);
    nxt.assign(2 * n - 1, 0);
    int ne = 0;
    for (int i = 1; i < n; ++i) {
        int x = 0, y = 0;
        std::scanf("%d%d", &x, &y);
        to[++ne] = y;
        nxt[ne] = head[x];
        head[x] = ne;
        to[++ne] = x;
        nxt[ne] = head[y];
        head[y] = ne;
    }

    cnt.assign(palette.size() + 1, 0);
    sz.assign(n + 1, 0);
    keep.assign(n + 1, 0);
    ans.assign(n + 1, 0);

    dfs1(1);
    s = 0;
    for (std::size_t c = 0; c < cnt.size(); ++c) s += cnt[c];  // answer for vertex 1
    dfs2(1);

    for (int i = 1; i <= n; ++i) std::printf("%lld ", ans[i]);
    return 0;
}