题目描述
给出一棵树和若干条直上直下的链,每条链有权值(非负)
用权值和尽量小的链覆盖树上所有的点
(n,m leq 1e6)
分析
猫老师课件上的题。
先考虑本问题在链上的形式,即最小权区间覆盖,这个问题有一个经典做法,即使用线段树优化DP。但是我们发现,由于树的形态,这个算法在树上不具有较强的扩展性(或许是因为博主太弱了没想到)。实际上,最小权区间覆盖问题也可以使用一个堆来解决。具体算法是:把所有区间按左端点排序,维护一个小根堆,堆中存的是一些二元组((c,d)),其中(c)是关键字,表示的是可以用(c)的代价覆盖([1,d])这段区间。从左往右扫所有区间,每扫到一个权值为(w)区间([l,r]),就把二元组((min_{in ext{heap}} {c}+w,r))加入堆中,当处理完所有左端点为(l)的区间后,删除堆中所有(d<l)的二元组。因为我们选择的区间不可能存在包含关系,所以算法正确性显然。
这个算法可以通过支持合并的数据结构扩展到树上,这里我们使用左偏树实现的可并堆。在树上多个子树的堆合并前,每个堆的所有元素要加上其他堆的(min c)之和。
代码
未经过对拍,不保证其正确性。
#include <bits/stdc++.h>
#define rin(i,a,b) for(register int i=(a);i<=(b);++i)
#define irin(i,a,b) for(register int i=(a);i>=(b);--i)
#define trav(i,a) for(register int i=head[a];i;i=e[i].nxt)
typedef long long LL;
using std::cin;
using std::cout;
using std::endl;
inline int read(){
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
const int MAXN=1e6+5;
int n,m,ecnt,head[MAXN];
int dep[MAXN];
LL minf[MAXN];
int root[MAXN],tot;
struct Edge{
int to,nxt;
}e[MAXN<<1];
struct path{
int top,val;
};
std::vector<path> vec[MAXN];
struct leftist{
int ch[2];
int dis,dep;
LL dat,tag;
}lt[MAXN];
inline void add_edge(int bg,int ed){
++ecnt;
e[ecnt].to=ed;
e[ecnt].nxt=head[bg];
head[bg]=ecnt;
}
void dfs1(int x,int pre,int depth){
dep[x]=depth;
trav(i,x){
int ver=e[i].to;
if(ver==pre) continue;
dfs1(ver,x,depth+1);
}
}
#define lc lt[x].ch[0]
#define rc lt[x].ch[1]
inline void pushtag(int x,LL _kk){
lt[x].dat+=_kk;
lt[x].tag+=_kk;
}
inline void pushdown(int x){
if(!lt[x].tag) return;
if(lc) pushtag(lc,lt[x].tag);
if(rc) pushtag(rc,lt[x].tag);
lt[x].tag=0;
}
int merge(int x,int y){
if(!x||!y) return x+y;
pushdown(x);pushdown(y);
if(lt[x].dat>lt[y].dat) std::swap(x,y);
rc=merge(rc,y);
if(lt[lc].dis<lt[rc].dis) std::swap(lc,rc);
lt[x].dis=lt[rc].dis+1;
return x;
}
int del(int x){
pushdown(x);
return merge(lc,rc);
}
#undef lc
#undef rc
void dfs2(int x,int pre){
LL temp=0;
trav(i,x){
int ver=e[i].to;
if(ver==pre) continue;
dfs2(ver,x);
temp+=minf[ver];
}
rin(i,0,(int)vec[x].size()-1){
lt[++tot]=(leftist){0,0,1,vec[x][i].top,vec[x][i].val+temp,0};
root[x]=merge(root[x],tot);
}
trav(i,x){
int ver=e[i].to;
if(ver==pre) continue;
pushtag(root[ver],temp-minf[ver]);
root[x]=merge(root[x],root[ver]);
}
while(lt[root[x]].dep>dep[x]) root[x]=del(root[x]);
minf[x]=lt[root[x]].dat;
}
int main(){
n=read();
rin(i,2,n){
int u=read(),v=read();
add_edge(u,v);
add_edge(v,u);
}
dfs1(1,0,1);
m=read();
rin(i,1,m){
int u=read(),v=read(),w=read();
if(dep[u]<dep[v]) std::swap(u,v);
vec[u].push_back((path){dep[v],w});
}
dfs2(1,0);
printf("%lld
",minf[1]);
return 0;
}
/*
7
1 2
1 3
2 4
2 5
3 6
3 7
6
1 4 3
1 5 2
2 5 1
1 6 5
6 6 1
3 7 2
7
*/