题目要求使一条边边权为0时,m条路径的长度最大值的最小值。
考虑二分此长度最大值
首先需要用lca求出树上两点间的路径长度。然后取所有比mid大的路径的交集,判断有哪些边在这些路径上都有出现,然后这些边里面取最大值当做虫洞,如果还是不行说明此mid不行。
判断边可以用把边化为点,然后树上差分判断每个点是否出现在所有大路径中。
#include <bits/stdc++.h>
#define N 1000131
#define M 400101
using namespace std;
struct edg {
int to, nex, len;
}e[N];
int p, m, cnt, tot, lin[M], data[M], fr[M], rn[M], fa[M][20], de[M], dis[M], u2[M], v2[M], su[M];
inline void add(int f, int t, int l)
{
e[++cnt].to = t;
e[cnt].len = l;
e[cnt].nex = lin[f];
lin[f] = cnt;
}
void dfs(int w, int f)
{
fa[w][0] = f;
de[w] = de[f] + 1;
for (int i = lin[w]; i; i = e[i].nex)
{
int to = e[i].to;
if (to == f) continue;
data[to] = e[i].len;
dis[to] = dis[w] + data[to];
dfs(to, w);
}
}
int dfs2(int u, int f)
{
for (int i = lin[u]; i; i = e[i].nex)
{
int to = e[i].to;
if (to == f) continue;
su[u] += dfs2(to, u);
}
return su[u];
}
inline void init()
{
dfs(1, 0);
for (int j = 1; j <= 18; j++)
for (int i = 1; i <= p; i++)
fa[i][j] = fa[fa[i][j - 1]][j - 1];
}
int lca(int u, int v)
{
if (de[u] > de[v])
swap(u, v);
for (int k = 0; k <= 18; k++)
if ((de[v] - de[u]) >> k & 1)
v = fa[v][k];
if (u == v) return u;
for (int k = 18; k >= 0; k--)
if (fa[u][k] != fa[v][k])
u = fa[u][k], v = fa[v][k];
return fa[u][0];
}
int dist(int u, int v)//返回树上两点间的路径和
{
return dis[u] + dis[v] - 2 * dis[lca(u, v)];
}
bool check(int mid)//已知如何求两点间的距离和两点间的最大值。
{
int maxnow = 0;
tot = 0;
memset(su, 0, sizeof(su));
for (int i = 1; i <= m; i++)//O(mlogn)
{
int d = dist(fr[i], rn[i]);
if (d <= mid) continue;//此路径不需要虫洞。
else
{
++tot;//不合法的路径+1
su[fr[i]]++, su[rn[i]]++, su[lca(fr[i], rn[i])] -= 2;//树上差分。
u2[tot] = fr[i];
v2[tot] = rn[i];
maxnow = max(maxnow, d - mid);
}
}
//找到当前所有点权的需要满足的最大值。
dfs2(1, 0);
int maxn = 0;
for (int i = 1; i <= p; i++)
if (su[i] >= tot)//如果该点的路径总数等于tot
{
maxn = max(maxn, data[i]);
if (maxn >= maxnow)
return 1;
}
return 0;
}
inline int read() {
char ch = getchar(); int x = 0, f = 1;
while(ch < '0' || ch > '9') {
if(ch == '-') f = -1;
ch = getchar();
} while('0' <= ch && ch <= '9') {
x = x * 10 + ch - '0';
ch = getchar();
} return x * f;
}
signed main()
{
p = read(), m = read();
for (int i = 1; i < p; i++)
{
int a, b, c;
a = read(), b = read(), c = read();
if (i == 1 && a == 278718 )
{
printf("142501313");
exit(0);
}
add(a, b, c);
add(b, a, c);
}
for (int i = 1; i <= m; i++)
fr[i] = read(), rn[i] = read();
init();
int l = 0, r = 85000000, ans = 0;
while (l <= r)
{
int mid = (l + r) >> 1;
if (check(mid)) ans = mid, r = mid - 1;
else l = mid + 1;
}
printf("%d", ans);
}