【BZOJ4297】[PA2015]Rozstaw szyn
Description
给定一棵有n个点,m个叶子节点的树,其中m个叶子节点分别为1到m号点,每个叶子节点有一个权值r[i]。你需要给剩下n-m个点各指定一个权值,使得树上相邻两个点的权值差的绝对值之和最小。
Input
第一行包含两个正整数n,m(2<=n<=500000,1<=m<=n),分别表示点数和叶子数。
接下来n-1行,每行两个正整数u,v(1<=u,v<=n),表示u与v之间有一条边。
接下来m行,每行一个正整数,依次为r[1],r[2],...,r[m](1<=r[i]<=500000),表示每个叶子的权值。
Output
输出一个整数,即树上相邻两个点的权值差的绝对值之和的最小值。
Sample Input
1 5
2 5
3 6
4 6
5 6
5
10
20
40
Sample Output
题解:思路同BZOJ1304,咱们先来证几个结论:
1.我们从下往上逐层贪心,每次选择一个点的取值范围时,只保证它与它的儿子之间差的绝对值之和最小,而不考虑它的父亲。这样为什么是对的呢?假如x的最优值为v,我们为了使它的父亲更优,将x的取值改为v+d,那么x与x父亲之间的差会减小d,但 x的所有值<=v的儿子 与x之间的差都增加了d。具体地,如果x有a个儿子,那么增加量至少是d。显然是没有一开始优的。
2.以哪个非叶子节点为根进行DP,最后得到的答案都是一样的。假如当前根为x,x的儿子是y。那么如果x的最优取值区间被y包含,相当于x和y之间的差可以为0,那么如果把y当成根,则y的取值区间显然也会被x包含(不要问为什么显然~)。否则我们不考虑x-y这条边,x的取值范围是[l,r],那么在考虑y的贡献后x的取值范围只可能是[...,l]或[r,...],即其他点对x的影响可视为不变,那么只需要最后加上x-y的贡献即可。把y当根也是同理,所以将那个点当成根答案都是一样的。
所以具体做法:随便找一个点当根进行DP,然后用每个点的儿子的最优取值区间来得到当前点的最优取值区间。具体地,我们将x的所有儿子的最优取值区间的左右端点放到一起排序,然后取中间的那段即可。
#include <cstdio> #include <cstring> #include <iostream> #include <algorithm> using namespace std; const int maxn=500010; typedef long long ll; int n,m,cnt; ll ans; int to[maxn<<1],next[maxn<<1],head[maxn],l[maxn],r[maxn],p[maxn<<1]; inline void add(int a,int b) { to[cnt]=b,next[cnt]=head[a],head[a]=cnt++; } void dfs(int x,int fa) { if(x<=m) return ; int i,tot=0; for(i=head[x];i!=-1;i=next[i]) if(to[i]!=fa) dfs(to[i],x); for(i=head[x];i!=-1;i=next[i]) if(to[i]!=fa) p[++tot]=l[to[i]],p[++tot]=r[to[i]]; sort(p+1,p+tot+1); l[x]=p[tot>>1],r[x]=p[(tot>>1)+1]; for(i=head[x];i!=-1;i=next[i]) if(to[i]!=fa&&(r[to[i]]<l[x]||l[to[i]]>l[x])) ans+=min(abs(l[to[i]]-l[x]),abs(r[to[i]]-l[x])); } inline int rd() { int ret=0,f=1; char gc=getchar(); while(gc<'0'||gc>'9') {if(gc=='-') f=-f; gc=getchar();} while(gc>='0'&&gc<='9') ret=ret*10+gc-'0',gc=getchar(); return ret*f; } int main() { //freopen("bz4297.in","r",stdin); n=rd(),m=rd(); int i,j,a,b; memset(head,-1,sizeof(head)); for(i=1;i<n;i++) a=rd(),b=rd(),add(a,b),add(b,a); for(i=1;i<=m;i++) l[i]=r[i]=rd(); if(n==m) { for(i=1;i<=n;i++) for(j=head[i];j!=-1;j=next[j]) ans+=abs(l[to[j]]-l[i]); printf("%lld",ans>>1); return 0; } dfs(n,0); printf("%lld",ans); return 0; }