又臭又长的点分治+树上乱搞
1 #include<bits/stdc++.h> 2 #define ll long long 3 using namespace std; 4 const int N=100010,oo=INT_MAX; 5 int n,color[N],cnt[N],cnt2[N],sub[N],nowchild; 6 ll sum[N]; 7 vector<int>g[N]; 8 int f[N],siz[N],root,tot_node,sumsiz[N],cursumsiz[N]; 9 int timer,timer2,colorvis[N],curcolorvis[N],isfirst[N]; 10 bool vis[N]; 11 void getroot(int k,int fa){ 12 siz[k]=1,f[k]=0; 13 int x; 14 for(int i=0;i<g[k].size();i++){ 15 x=g[k][i]; 16 if(x==fa||vis[x]) continue; 17 getroot(x,k); 18 siz[k]+=siz[x],f[k]=max(f[k],siz[x]); 19 } 20 f[k]=max(f[k],tot_node-siz[k]); 21 if(f[root]>f[k]) root=k; 22 return; 23 } 24 void dfs0(int k,int fa){ 25 siz[k]=1; 26 int x; 27 for(int i=0;i<g[k].size();i++){ 28 x=g[k][i]; 29 if(x==fa||vis[x]) continue; 30 dfs0(x,k); 31 siz[k]+=siz[x]; 32 } 33 return; 34 } 35 void getsum(int k,int fa){ 36 int x; 37 if(cnt[color[k]]==1&&fa){ 38 if(colorvis[color[k]]!=timer) colorvis[color[k]]=timer,sumsiz[color[k]]=0; 39 sumsiz[color[k]]+=siz[k]; 40 } 41 for(int i=0;i<g[k].size();i++){ 42 x=g[k][i]; 43 if(x==fa||vis[x]) continue; 44 cnt[color[x]]++; 45 getsum(x,k); 46 } 47 cnt[color[k]]--; 48 } 49 void fix(int k,int fa){ 50 int x; 51 if(isfirst[k]==timer){ 52 sub[k]-=sumsiz[color[k]]-cursumsiz[color[k]]; 53 } 54 for(int i=0;i<g[k].size();i++){ 55 x=g[k][i]; 56 if(x==fa||vis[x]) continue; 57 fix(x,k); 58 } 59 return; 60 } 61 void dfs1(int k,int fa,int r){ 62 int x; 63 if(cnt[color[k]]==1&&fa){ 64 isfirst[k]=timer; 65 if(curcolorvis[color[k]]!=timer2) curcolorvis[color[k]]=timer2,cursumsiz[color[k]]=0; 66 sub[k]+=tot_node-siz[nowchild],sub[r]+=siz[k],sub[nowchild]-=siz[k]; 67 cursumsiz[color[k]]+=siz[k]; 68 } 69 for(int i=0;i<g[k].size();i++){ 70 x=g[k][i]; 71 if(x==fa||vis[x]) continue; 72 if(!fa) nowchild=x,timer2++; 73 cnt[color[x]]++; 74 dfs1(x,k,r); 75 if(!fa) fix(x,k); 76 } 77 cnt[color[k]]--; 78 return; 79 } 80 void dfs2(int k,int fa){ 81 int x; 82 sum[k]+=sub[k]; 83 for(int i=0;i<g[k].size();i++){ 84 x=g[k][i]; 85 if(x==fa||vis[x]) continue; 86 sub[x]+=sub[k]; 87 dfs2(x,k); 88 } 89 sub[k]=0; 90 return; 91 } 92 void calc(int k){ 93 int x; 94 timer++; 95 dfs0(k,0); 96 cnt[color[k]]++; 97 getsum(k,0); 98 cnt[color[k]]++; 99 dfs1(k,0,k); 100 for(int i=0;i<g[k].size();i++){ 101 x=g[k][i]; 102 if(vis[x]) continue; 103 sub[x]+=tot_node-siz[x]; 104 } 105 sum[k]+=tot_node; 106 dfs2(k,0); 107 return; 108 } 109 void work(int k){ 110 vis[k]=1; 111 int x; 112 calc(k); 113 for(int i=0;i<g[k].size();i++){ 114 x=g[k][i]; 115 if(vis[x]) continue; 116 root=0,tot_node=siz[x]; 117 getroot(x,0); 118 work(root); 119 } 120 return; 121 } 122 int main(){ 123 int t1,t2; 124 f[0]=oo; 125 scanf("%d",&n); 126 for(int i=1;i<=n;i++) scanf("%d",&color[i]); 127 for(int i=1;i<n;i++){ 128 scanf("%d%d",&t1,&t2); 129 g[t1].push_back(t2);g[t2].push_back(t1); 130 } 131 root=0,tot_node=n; 132 getroot(1,0); 133 work(root); 134 for(int i=1;i<=n;i++)printf("%lld ",sum[i]); 135 return 0; 136 }