题意简述
给定一棵 (n) 的节点的树,根为1,每个点有权值 (M_i)
要把树分成若干段,每段内不存在“祖先-后代”关系,定义一个段的大小为段中点 (M_i) 的最大值。
求所有段的大小之和的最小值。
(nleq 2 imes 10^5)
想法
想法一:奇怪的贪心
每次找到全树中没分到段内的有最大权值的点 (u) ,则 (M_u) 无论如何都要加到最终答案中。
那么下一步贪心就是找到非 (u) 的祖先且非 (u) 的后代的点中有最大权值的点 (v) ,与 (u) 分到一个段中。
然后再找与 (u) 、(v) 都不存在 “祖先-后代”关系的有最大权值的点加入该段中……以此类推,直到找不到可加入的点,这一个段结束。
要证的话,大概就交换一下。
假设最优解中 (v) 与 (u) 不在同一个段中,将 (v) 与 (u) 所在段中与 (v) 不兼容的点交换位置,之后仍满足要求且不会更差。
怎样找非“祖先-后代”关系的点中权值最大的点呢?
树剖+线段树。
段中每加入一个点,就把它的祖先和后代在线段树中“盖住”,一个段结束后再统一把所有“盖子”都去掉。
很坑,细节极多,实在需要好好注意。(我在细节写炸后还怀疑是算法错了呢 (qwq) )
想法二:靠谱一些的贪心
树上问题,先考虑子树。
假设已求出子树最优情况下各个段的大小,在对子树进行合并时,显然各自树的最大段合成一段,次大段合成一段,以此类推……
(我也不知道怎么证,但看起来就很靠谱)
对每个子树搞个堆,启发式合并就可以了。
总结
树上问题,考虑子树……
代码
启发式合并
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<queue>
using namespace std;
int read(){
int x=0;
char ch=getchar();
while(!isdigit(ch)) ch=getchar();
while(isdigit(ch)) x=x*10+ch-'0',ch=getchar();
return x;
}
const int N = 200005;
typedef long long ll;
struct node{
int v;
node *nxt;
}pool[N],*h[N];
int cnt;
void addedge(int u,int v){
node *p=&pool[++cnt];
p->v=v;p->nxt=h[u];h[u]=p;
}
int n,M[N],f[N],tl;
int st[N];
priority_queue<int> q[N];
void merge(int u,int v){
if(q[u].size()<q[v].size()) swap(q[u],q[v]);
tl=0;
while(q[v].size()){
st[tl++]=max(q[v].top(),q[u].top());
q[v].pop(); q[u].pop();
}
for(int i=0;i<tl;i++) q[u].push(st[i]);
}
void work(int u){
int v;
for(node *p=h[u];p;p=p->nxt)
work(v=p->v),merge(u,v);
q[u].push(M[u]);
}
int main()
{
n=read();
for(int i=1;i<=n;i++) M[i]=read();
for(int i=2;i<=n;i++) f[i]=read(),addedge(f[i],i);
work(1);
ll ans=0;
for(;q[1].size();) ans+=q[1].top(),q[1].pop();
printf("%lld
",ans);
return 0;
}
树剖+线段树
#include<cstdio>
#include<iostream>
#include<algorithm>
using namespace std;
int read(){
int x=0;
char ch=getchar();
while(!isdigit(ch)) ch=getchar();
while(isdigit(ch)) x=x*10+ch-'0',ch=getchar();
return x;
}
const int N = 200005;
typedef pair<int,int> Pr;
typedef long long ll;
int n,M[N],f[N];
struct node{
int v;
node *nxt;
}pool[N],*h[N];
int cnt1;
void addedge(int u,int v){
node *p=&pool[++cnt1];
p->v=v;p->nxt=h[u];h[u]=p;
}
int dfn[N],tot,son[N],sz[N],top[N],re[N],out[N];
void dfs1(int u){
int v,Mson=0;
sz[u]=1;
for(node *p=h[u];p;p=p->nxt){
dfs1(v=p->v);
sz[u]+=sz[v];
if(sz[v]>Mson) son[u]=v,Mson=sz[v];
}
}
void dfs2(int u){
int v=son[u];
if(v){
top[v]=top[f[v]];
dfn[v]=++tot;
re[tot]=v;
dfs2(v);
}
for(node *p=h[u];p;p=p->nxt)
if(!dfn[v=p->v]){
top[v]=v;
dfn[v]=++tot;
re[tot]=v;
dfs2(v);
}
out[u]=tot;
}
int cnt,root,ch[N*2][2],cov[N*2],use[N*2];
Pr ori[N*2],mx[N*2];
void build(int x,int l,int r){
cov[x]=0;
if(l==r) { mx[x]=Pr(M[re[l]],re[l]); ori[x]=mx[x]; return; }
int mid=(l+r)>>1;
build(ch[x][0]=++cnt,l,mid);
build(ch[x][1]=++cnt,mid+1,r);
mx[x]=max(mx[ch[x][0]],mx[ch[x][1]]);
ori[x]=mx[x];
}
void push(int x){
if(!x) return;
use[x]=1; cov[x]=1; mx[x]=Pr(0,0);
}
void pushdown(int x){
if(!cov[x]) return;
push(ch[x][0]); push(ch[x][1]);
cov[x]=0;
}
void modify(int x,int l,int r,int L,int R){
use[x]=1;
if(L<=l && r<=R) { push(x); return; }
pushdown(x);
int mid=(l+r)>>1;
if(L<=mid) modify(ch[x][0],l,mid,L,R);
if(R>mid) modify(ch[x][1],mid+1,r,L,R);
mx[x]=max(mx[ch[x][0]],mx[ch[x][1]]);
}
void recover(int x,int l,int r){
if(!use[x]) return;
use[x]=0; cov[x]=0; mx[x]=ori[x];
if(l==r) return;
int mid=(l+r)>>1;
recover(ch[x][0],l,mid); recover(ch[x][1],mid+1,r);
}
Pr Max(int x,int l,int r,int L,int R){
if(L<=l && r<=R) return mx[x];
pushdown(x);
int mid=(l+r)>>1;
Pr ret(0,0);
if(L<=mid) ret=max(ret,Max(ch[x][0],l,mid,L,R));
if(R>mid) ret=max(ret,Max(ch[x][1],mid+1,r,L,R));
return ret;
}
void change(int x,int l,int r,int c){
use[x]=1;
if(l==r) { mx[x]=ori[x]=Pr(0,0); return; }
pushdown(x);
int mid=(l+r)>>1;
if(c<=mid) change(ch[x][0],l,mid,c);
else change(ch[x][1],mid+1,r,c);
mx[x]=max(mx[ch[x][0]],mx[ch[x][1]]);
ori[x]=max(ori[ch[x][0]],ori[ch[x][1]]);
}
void jump(int x){
modify(root,1,n,dfn[x],out[x]);/**/
while(x){
modify(root,1,n,dfn[top[x]],dfn[x]);
x=f[top[x]];
}
}
int main()
{
n=read();
for(int i=1;i<=n;i++) M[i]=read();
for(int i=2;i<=n;i++) f[i]=read(),addedge(f[i],i);
dfs1(1);
top[1]=1; dfn[1]=++tot; re[tot]=1; dfs2(1);
build(root=++cnt,1,n);
ll ans=0;
int t=n;
while(t){
Pr w=Max(root,1,n,1,n);
ans+=w.first;
change(root,1,n,dfn[w.second]); t--; /*dfn*/
jump(w.second);
for(;;){
w=Max(root,1,n,1,n);
if(w.second==0) break;
change(root,1,n,dfn[w.second]); t--; /*dfn*/
jump(w.second);
}
recover(root,1,n);
}
printf("%lld
",ans);
return 0;
}