题目链接:https://www.luogu.org/problemnew/show/P2680
原先写的题,忘了补题解了。。。咕咕咕~~今天重新写了一遍,现在补上。
题意就是一棵树,我们要使得一条边的权值为0来达成题目给定的几条路径的最大长度最小。
就是当我们做最优解问题的时候,如果不好做,可以考虑从使用枚举答案然后将其转化为判定性问题。而答案的枚举采用二分来加速。(当然,这种最小化最大值的问题一眼应该就能看出来是二分了)
然后呢。。。当然,因为要求路径之间的长度,设路径的两端点分别为(a,b),那么(dis(a,b)=dis(root,a)+dis(root,b)-2 imes dis(root,lca(a,b)))。
之后最大的问题就是这个check函数怎么写:
我们考虑二分出来一个长度(k),然后满足去掉一条边之后(也就是权值为0)所有的路径的长度都不超过它即为return true。当然如果不去掉任何一条边就可以满足条件的话肯定是return true。下面我们来处理如果有一些路径长度超过(k)的情况:
如果说有一些边超过(k),因为只能删去一个。所以我们肯定是要删去这些路径的交集上面的边才有可能满足return true的条件。而怎么判定一条边是不是在这些所有长度超出(k)的路径上呢,我们可以采取树上边差分的策略,开一个(tot)数组来记录((tot[i])表示节点(i)到它的父亲这条边),如果tot==这些路径的数量,就在交集上面了。
之后就没有什么了,代码如下:
#include<cstdio>
#include<algorithm>
#include<cstring>
#define MAXN 300010
using namespace std;
struct Edge{int to,nxt,w;}edge[MAXN<<1];
struct Node{int u,v,lcaa,diss;}node[MAXN<<1];
int summ,cnt=0,n,m,edge_number;
int num[MAXN],mi[MAXN],vis[MAXN];
int tot[MAXN],head[MAXN],deep[MAXN],dis[MAXN],lg[MAXN],fa[MAXN][32],dp[MAXN][32];
inline void add(int from,int to,int dis){
edge[++edge_number].to=to;
edge[edge_number].nxt=head[from];
edge[edge_number].w=dis;
head[from]=edge_number;
}
inline void init()
{
for(int i=1;i<=n;i++)
lg[i]=lg[i-1]+(1<<lg[i-1]==i);
}
inline void search(int x,int f)
{
deep[x]=deep[f]+1;
num[++cnt]=x;
fa[x][0]=f;
for(int i=1;(1<<i)<=deep[x];i++)
fa[x][i]=fa[fa[x][i-1]][i-1];
for(int i=head[x];i;i=edge[i].nxt)
{
int v=edge[i].to;
if(v==f) continue;
dis[v]=dis[x]+edge[i].w;
search(v,x);
}
}
inline int lca(int x,int y)
{
if(deep[x]<deep[y]) swap(x,y);
while(deep[x]>deep[y])
x=fa[x][lg[deep[x]-deep[y]]-1];
if(x==y) return x;
for(int i=lg[deep[x]];i>=0;i--)
if(fa[x][i]!=fa[y][i])
x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
inline bool check(int mid){
int cnt=0,ans=0,maxx=0;
memset(tot,0,sizeof(tot));
for(int i=1;i<=m;i++){
if(node[i].diss>mid){
tot[node[i].u]++;tot[node[i].v]++;tot[node[i].lcaa]-=2;
maxx=max(maxx,node[i].diss);
cnt++;
}
}
if(cnt==0) return true;
for(int i=n;i>=1;i--)
tot[fa[num[i]][0]]+=tot[num[i]];
for(int i=2;i<=n;i++)
if(tot[i]==cnt&&maxx-(dis[i]-dis[fa[i][0]])<=mid)
return true;
return false;
}
int main(){
scanf("%d%d",&n,&m);
for(int i=1;i<=n-1;i++){
int x,y,w;
scanf("%d%d%d",&x,&y,&w);
add(x,y,w);
add(y,x,w);
summ+=w;
}
init();
dis[1]=0;
search(1,0);
for(int i=1;i<=m;i++){
scanf("%d%d",&node[i].u,&node[i].v);
node[i].lcaa=lca(node[i].u,node[i].v);
node[i].diss=dis[node[i].u]+dis[node[i].v]-2*dis[node[i].lcaa];
}
int left=0,right=summ;
int mid;
while(left<right){
mid=(left+right)>>1;
if(check(mid)) right=mid;
else left=mid+1;
}
printf("%d",left);
return 0;
}