运输计划(二分答案+lca)
有一颗n个节点的树,上面有m条路径。现在可以把一条边的长度变成零,问所有路径长度最大值的最小是多少。(n,m≤3e5)。
据说求最大值的最小值,或者最小值的最大值都要二分答案?二分路径长度的最大值l,那么我们必须修改一条边,使得所有路径长度比l大的路径,长度变为小于l。所以我们找到的需要更改的边,必须被每一个比l大的路径经过。那么如何找到可能被更改的边呢?树上差分即可。对于每条路径,在其两个端点处打上+1标记,在lca处打上-2标记。然后统计该节点子树的和,若和为路径条数,说明所有路径都经过它上面那条边,所以那条边可以计入计算。这样,找到符号条件且权值最大的边,如果最长的路径减去它,小于l,那么说明这条边权值变成零,所有路径的长度最大值可以小于l,那么l可以更小。否则l需要更大。
#include <cstdio>
#include <cstring>
#include <algorithm>
const int maxn=3e5+5, maxlog=20;
class Graph{
public:
struct Edge{
int to, next, v; Graph *bel;
Edge& operator ++(){
return *this=bel->edge[next]; }
inline int operator *(){ return to; }
};
void addedge(int x, int y, int v){
Edge &e=edge[++cntedge];
e.to=y; e.next=fir[x]; e.v=v;
e.bel=this; fir[x]=cntedge;
}
Edge& getlink(int x){ return edge[fir[x]]; }
private:
int cntedge, fir[maxn];
Edge edge[maxn*2];
};
struct Mark{
void set(int x, int y, int z){
pos=x; l=y; type=z; }
int pos, l, type;
}tmp, mark[maxn*3];
bool cmp(const Mark &a, const Mark &b){
return a.l<b.l;
}
Graph g;
int n, m, cntmark, maxlen, maxedge, num, dep[maxn], add[maxn];
int p[maxn][maxlog], dis[maxn][maxlog], fav[maxn];
void dfs(int now, int par){
Graph::Edge e=g.getlink(now);
p[now][0]=par;
for (int i=1; i<maxlog; ++i)
p[now][i]=p[p[now][i-1]][i-1];
dis[now][0]=fav[now];
for (int i=1; i<maxlog; ++i)
dis[now][i]=dis[now][i-1]+dis[p[now][i-1]][i-1];
for (; *e; ++e){
if (*e==par) continue;
dep[*e]=dep[now]+1;
fav[*e]=e.v;
dfs(*e, now);
}
}
void solve(int x, int y){
int tx=x, ty=y, len=0, lca;
if (dep[x]>dep[y]) std::swap(x, y);
for (int i=maxlog-1; i>=0; --i)
if (dep[p[y][i]]>=dep[x]){
len+=dis[y][i];
y=p[y][i];
}
for (int i=maxlog-1; i>=0; --i)
if (p[x][i]!=p[y][i]){
len+=dis[x][i]+dis[y][i];
x=p[x][i]; y=p[y][i];
}
if (x!=y){
len+=dis[x][0]+dis[y][0];
x=p[x][0];
} lca=x;
mark[cntmark++].set(tx, len, 1);
mark[cntmark++].set(ty, len, 1);
mark[cntmark++].set(lca, len, -2);
if (len>maxlen) maxlen=len;
}
int dfs2(int now, int par){
Graph::Edge e=g.getlink(now); int cnt=0;
for (; *e; ++e){
if (*e==par) continue;
cnt+=dfs2(*e, now);
}
cnt+=add[now];
if (cnt==num) maxedge=std::max(maxedge, fav[now]);
return cnt;
}
int main(){
scanf("%d%d", &n, &m);
int x, y, t, L, R, mid;
for (int i=1; i<n; ++i){
scanf("%d%d%d", &x, &y, &t);
g.addedge(x, y, t); g.addedge(y, x, t);
}
dep[1]=1; dfs(1, 0);
for (int i=0; i<m; ++i){
scanf("%d%d", &x, &y); solve(x, y); }
std::sort(mark, mark+cntmark, cmp);
L=0; R=maxlen;
while (L<R){
mid=(L+R)>>1; tmp.l=mid;
memset(add, 0, sizeof(add));
x=std::upper_bound(mark, mark+cntmark, tmp
, cmp)-mark;
num=(cntmark-x)/3;
for (int i=x; i<cntmark; ++i)
add[mark[i].pos]+=mark[i].type;
maxedge=0; dfs2(1, 0);
if (maxlen-maxedge<=mid) R=mid;
else L=mid+1;
}
printf("%d
", L);
return 0;
}