题目大意:给出一颗含有$n$个结点的树,每个节点有一个颜色。求树中每个子树最多的颜色的编号和。
-------------------------
树上启发式合并(dsu on tree)。
我们先考虑暴力怎么做。遍历整颗树,暴力枚举子树然后用桶维护颜色个数。这样做是$O(n^2)$的,显然会T。我们需要一种更快的算法:树上启发式合并。
关于启发式算法的介绍,详见OI Wiki。本文只介绍树上启发式合并算法。本题的解法:
每处理完一颗子树,我们都要把桶清空一次,以免对它的兄弟造成影响。而这样做还要从它的祖先遍历一遍,浪费时间。
我们发现:遍历最后一颗子树时,桶是不用清空的。因为遍历完那颗子树后可以直接把答案加入$ans$中。那我们肯定选重儿子啊,省时省力。遍历轻儿子相对不费事。
看起来是不是没有快多少?实际上它是$O(nlog n)$的。下面是证明:
对于每个节点,它被计算的次数就是它到根节点路径的轻边个数。
而结点往上跳一次,子树大小至少为原来两倍,所以轻边个数最多是$log n$。所以时间复杂度$O(nlog n)$。
证明过程跟树链剖分的有点像。
代码:
#include<bits/stdc++.h> #define int long long using namespace std; int n,color[200005],bucket[200005],ans[200005]; int size[200005],son[200005],sum,mx; int head[200005],cnt; struct node { int next,to; }edge[200005]; inline int read() { int x=0,f=1;char ch=getchar(); while(!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();} while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();} return x*f; } inline void add(int from,int to) { edge[++cnt].next=head[from]; edge[cnt].to=to; head[from]=cnt; } inline void dfs_son(int now,int fa) { size[now]=1; int mx=0,p=0; for (int i=head[now];i;i=edge[i].next) { int to=edge[i].to; if (to==fa) continue; dfs_son(to,now); size[now]+=size[to]; if (size[to]>mx) { mx=size[to]; p=to; } } if (p) son[p]=1; } void getans(int x,int f,int p){ bucket[color[x]]++; if(bucket[color[x]]>mx){ mx=bucket[color[x]]; sum=color[x]; }else if(bucket[color[x]]==mx)sum+=color[x]; for(int i=head[x];i;i=edge[i].next){ int y=edge[i].to; if(y==f || y==p)continue; getans(y,x,p); } } inline void init(int now,int fa) { bucket[color[now]]--; for (int i=head[now];i;i=edge[i].next) { int to=edge[i].to; if (to==fa) continue; init(to,now); } } inline void dfs(int now,int fa) { int p=0; for (int i=head[now];i;i=edge[i].next) { int to=edge[i].to; if (to==fa) continue; if (!son[to]) { dfs(to,now); init(to,now); sum=mx=0; } else p=to; } if (p) dfs(p,now); getans(now,fa,p); ans[now]=sum; } signed main() { n=read(); for (int i=1;i<=n;i++) color[i]=read(); for (int i=1;i<n;i++) { int x=read(),y=read(); add(x,y);add(y,x); } dfs_son(1,0); dfs(1,0); for (int i=1;i<=n;i++) printf("%lld ",ans[i]); return 0; }