题目内容
给出一个树,求出每个节点的子树中出现次数最多的颜色的编号和。
思路
那些年,关于板子题的那些事
代码
Dsu on tree:
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int maxn=1e5+10;
int n;
int col[maxn],res[maxn];
struct Edge{
int from,to,nxt;
}e[maxn<<1];
inline int read(){
int x=0;bool fopt=1;char ch=getchar();
for(;!isdigit(ch);ch=getchar())if(ch=='-')fopt=0;
for(;isdigit(ch);ch=getchar())x=(x<<3)+(x<<1)+ch-48;
return fopt?x:-x;
}
int head[maxn],cnt;
inline void add(int u,int v){
e[++cnt].from=u;
e[cnt].to=v;
e[cnt].nxt=head[u];
head[u]=cnt;
}
int dep[maxn],siz[maxn],son[maxn],fa[maxn];
void dfs(int u){
dep[u]=dep[fa[u]]+1;siz[u]=1;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==fa[u])continue;
fa[v]=u;
dfs(v);
siz[u]+=siz[v];
if(!son[u]||siz[son[u]]<siz[v])son[u]=v;
}
}
int Max=-1,sum;
int buc[maxn];
void modify(int u,int val,int pos){
buc[col[u]]+=val;
if(buc[col[u]]>Max){
Max=buc[col[u]];sum=col[u];
}else if(buc[col[u]]==Max)sum+=col[u];
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==fa[u]||v==pos)continue;
modify(v,val,pos);
}
}
void dsu(int u,bool isson){
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==fa[u]||v==son[u])continue;
dsu(v,0);
}
if(son[u])dsu(son[u],1);
modify(u,1,son[u]);
res[u]=sum;
if(!isson){
modify(u,-1,0);
sum=Max=0;
}
}
signed main(){
n=read();
for(int i=1;i<=n;i++)
col[i]=read();
for(int i=1;i<n;i++){
int u=read(),v=read();
add(u,v);add(v,u);
}
dfs(1);
dsu(1,0);
for(int i=1;i<=n;i++)
printf("%lld ",res[i]);
return 0;
}
线段树合并:
#include <bits/stdc++.h>
#define int long long
#define lson tr[rt].l
#define rson tr[rt].r
using namespace std;
const int maxn=1e5+10;
int n,tot;
int col[maxn],root[maxn],res[maxn];
struct Edge{
int from,to,nxt;
}e[maxn<<1];
struct SEG{
int l,r,sum,ans;
}tr[maxn*32];
inline int read(){
int x=0;bool fopt=1;char ch=getchar();
for(;!isdigit(ch);ch=getchar())if(ch=='-')fopt=0;
for(;isdigit(ch);ch=getchar())x=(x<<3)+(x<<1)+ch-48;
return fopt?x:-x;
}
int head[maxn],cnt;
inline void add(int u,int v){
e[++cnt].from=u;
e[cnt].to=v;
e[cnt].nxt=head[u];
head[u]=cnt;
}
inline void pushup(int rt){
int u=tr[lson].sum>=tr[rson].sum?lson:rson;
tr[rt].sum=tr[u].sum;
if(tr[lson].sum==tr[rson].sum)
tr[rt].ans=tr[lson].ans+tr[rson].ans;
else tr[rt].ans=tr[u].ans;
}
inline void modify(int &rt,int l,int r,int pos,int v){
if(!rt)rt=++tot;
if(l==r){
tr[rt].sum+=v;
tr[rt].ans=l;
return;
}
int mid=(l+r)>>1;
if(pos<=mid)modify(lson,l,mid,pos,v);
else modify(rson,mid+1,r,pos,v);
pushup(rt);
}
inline int Merge(int a,int b,int l,int r){
if(!a||!b)return a+b;
if(l==r){
tr[a].sum+=tr[b].sum;
tr[a].ans=l;
return a;
}
int mid=(l+r)>>1;
tr[a].l=Merge(tr[a].l,tr[b].l,l,mid);
tr[a].r=Merge(tr[a].r,tr[b].r,mid+1,r);
pushup(a);
return a;
}
void dfs(int u,int fa){
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==fa)continue;
dfs(v,u);
Merge(root[u],root[v],1,1e5);
}
modify(root[u],1,1e5,col[u],1);
res[u]=tr[root[u]].ans;
}
signed main(){
n=read();tot=n;
for(int i=1;i<=n;i++){
col[i]=read();
root[i]=i;
}
for(int i=1;i<n;i++){
int u=read(),v=read();
add(u,v);add(v,u);
}
dfs(1,0);
for(int i=1;i<=n;i++)
printf("%lld ",res[i]);
return 0;
}