题意:
多组输入。给定一棵树,每条边有两个可能的取值a和b,整棵树有k条边的值来自a,其余边的值来自b,问树的直径最小值为多少?
取值范围:k<=min(20,n-1),n<=20000,(sum n<=200000)
解法:
和直径有关,一般要用到树上dp
发现直接求解直径的最小值难以实现,考虑二分答案,检验答案是否可行。
二分树的直径,已知直径,如何检验是否可满足?注意到k<=20,可以采用k^2的背包进行状态转移。
建立dp数组(dp[n][k]),表示第n个点的子树中,有k条边取值来自a,点n到子树的叶子节点的最长路。(通常情况是这样,也有其他情况下文会说明)
对于点u,遍历到了v,边长度为a或b,枚举点u已经搜索到的子树中取a的个数和v子树中取a的个数,可以得到状态转移方程:
(dp[u][i+j+1]=min(dp[u][i+j+1],max(dp[u][i],dp[v][j]+a))
(dp[u][i+j]=min(dp[u][i+j],max(dp[u][i],dp[v][j]+b))
因为要检验mid是否可行,可以在转移之前得出此时u的子树的直径长度len=(dp[u][i]+dp[v][j]+a)或(dp[u][i]+dp[v][j]+b),只有在len<=mid时候才进行转移(因为此时对应了一种在u的子树上是可行的方案,如果不满足上述式子,这种方案就是不可行的,不需要更新)。如果一个子树没有一种可行的方案,那么mid必然就是不可满足的。
因为在进行背包过程时,u的子树是不包含v的,所以不能在循环中直接对dp数组进行更新,要另开一个temp数组存储;为了检测该点是否有可以满足的方案(即该点是否被更新),先将temp数组都初始化为mid+1,这样如果该点的一种方案是不可满足的,则对应的(dp[u][i])就被置为mid+1,在向上传递的过程中所有枚举到这种子方案的方案都将会是不可满足的(直径>mid),这种不可满足性会一直向上传递(这就是上文提到的dp值不为最长路的情况)。因此如果dfs结束后,(dp[1][k])不为mid+1,那么mid的情况必然就是可以满足的,反之是不可满足的。
代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=2e4+5;
const int maxm=4e5+5;
struct edge{
int next,v;
ll a,b;
}E[maxm];
int head[maxn],tot;
void addedge(int u,int v,ll a,ll b){
E[++tot]=edge{head[u],v,a,b};
head[u]=tot;
}
ll dp[maxn][25];
int sz[maxn];
int n,k;
void dfs(int u,int fa,ll lim){
sz[u]=0;
for(int i=0;i<=k;i++){
dp[u][i]=0;
}
for(int ed=head[u];ed;ed=E[ed].next){
int v=E[ed].v;
if(v==fa)continue;
dfs(v,u,lim);
int now=min(sz[u]+sz[v]+1,k);
ll temp[25];
for(int i=0;i<=now;i++){
temp[i]=lim+1;
}
for(int i=0;i<=sz[u];i++){
for(int j=0;j<=sz[v]&&i+j<=k;j++){
if(dp[u][i]+dp[v][j]+E[ed].a<=lim){
temp[i+j+1]=min(temp[i+j+1],max(dp[u][i],dp[v][j]+E[ed].a));
}
if(dp[u][i]+dp[v][j]+E[ed].b<=lim){
temp[i+j]=min(temp[i+j],max(dp[u][i],dp[v][j]+E[ed].b));
}
}
}
for(int i=0;i<=now;i++){
dp[u][i]=temp[i];
}
sz[u]=now;
}
}
void init(int n){
memset(head,0,sizeof(int)*(n+1));
tot=0;
}
int main () {
int T;
scanf("%d",&T);
while(T--){
scanf("%d%d",&n,&k);
init(n);
for(int i=1;i<=n-1;i++){
int u,v;
ll a,b;
scanf("%d%d%lld%lld",&u,&v,&a,&b);
addedge(u,v,a,b);
addedge(v,u,a,b);
}
ll l=1,r=2e13;//左开右闭
while(r-l>1){
ll mid=(l+r)>>1;
dfs(1,0,mid);
if(dp[1][k]<=mid){
r=mid;
}
else{
l=mid;
}
}
printf("%lld
",r);
}
}