BZOJ原题链接
洛谷原题链接
用(LCA)初始化出所有运输计划的原始时间,因为答案有单调性,所以二分答案,然后考虑检验答案。
很容易想到将所有超出当前二分的答案的运输计划所经过的路径标记,在这些运输计划都经过的边中的权值最大的这条边上建立虫洞,如果能使得所有运输计划中需要时间最多的那个计划能在当前二分的答案之内完成,那么这个答案就是可行的。
然后考虑如何快速找到这些运输计划都经过的边,本质上我们要求的就是一条边被覆盖几次,所以考虑树上差分。
定义(dif[i])表示点(i)到它父亲节点的这条边被覆盖几次,对于一个计划((x,y)),我们将(dif[x]++,dif[y]++,dif[LCA(x,y)]-2),最后将点(i)的子树中所有的(dif)累加到(dif[i])即可。
不过这题挺卡常的,所以要注意常数。
- 不在求(LCA)的同时求经过路径长度,而是用该表达式(dis[x]+dis[y]-dis[LCA(x,y)] imes 2)((dis[x])表示点(x)到根的距离)。
- 缩小二分答案的范围,上界定为最长链的长度,即所有运输计划中需要时间最多的那个计划,下界定为最长链减去这条链上权值最大的边。
- 计算(dif)不用(dfs),而是在初始化倍增(LCA)的时候求出时间戳,然后枚举时间戳来累加并判断。
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<cstring>
using namespace std;
const int N = 3e5 + 10;
const int M = 19;
struct ts {
int x, y, t, po;
};
ts a[N];
int fi[N], di[N << 1], ne[N << 1], da[N << 1], f[N][M], g[N], dif[N], dis[N], de[N], dfn[N], l, gn, n, m, ti;
inline int re()
{
int x = 0;
char c = getchar();
bool p = 0;
for (; c < '0' || c > '9'; c = getchar())
p |= c == '-';
for (; c >= '0' && c <= '9'; c = getchar())
x = x * 10 + c - '0';
return p ? -x : x;
}
inline void add(int x, int y, int z)
{
di[++l] = y;
da[l] = z;
ne[l] = fi[x];
fi[x] = l;
}
inline void sw(int &x, int &y)
{
int z = x;
x = y;
y = z;
}
inline int maxn(int x, int y)
{
return x > y ? x : y;
}
bool comp(ts x, ts y)
{
return x.t > y.t;
}
void dfs(int x)
{
int i, y;
dfn[++ti] = x;
for (i = 1; i <= gn; i++)
{
f[x][i] = f[f[x][i - 1]][i - 1];
if (!f[x][i])
break;
}
for (i = fi[x]; i; i = ne[i])
if ((y = di[i]) ^ f[x][0])
{
f[y][0] = x;
dis[y] = dis[x] + (g[y] = da[i]);
de[y] = de[x] + 1;
dfs(y);
}
}
int lca(int x, int y)
{
int i;
if (de[x] > de[y])
sw(x, y);
for (i = gn; ~i; i--)
if (de[f[y][i]] >= de[x])
y = f[y][i];
if (!(x ^ y))
return x;
for (i = gn; ~i; i--)
if (f[x][i] ^ f[y][i])
{
x = f[x][i];
y = f[y][i];
}
return f[x][0];
}
int fin(int x, int y, int z)
{
int ma = 0;
for (; x ^ z; x = f[x][0])
ma = maxn(ma, g[x]);
for (; y ^ z; y = f[y][0])
ma = maxn(ma, g[y]);
return ma;
}
bool judge(int mid)
{
int i, k = a[1].t - mid, o = 0, x;
memset(dif, 0, sizeof(dif));
for (i = 1; i <= m; i++)
{
if (a[i].t <= mid)
break;
dif[a[i].x]++;
dif[a[i].y]++;
dif[a[i].po] -= 2;
o++;
}
if (!o)
return true;
for (i = n; i; i--)
{
x = dfn[i];
dif[f[x][0]] += dif[x];
if (!(dif[x] ^ o) && g[x] >= k)
return true;
}
return false;
}
int main()
{
int i, x, y, z, l, r, mid;
n = re();
m = re();
for (i = 1; i < n; i++)
{
x = re();
y = re();
z = re();
add(x, y, z);
add(y, x, z);
}
gn = log2(n);
de[1] = 1;
dfs(1);
for (i = 1; i <= m; i++)
{
a[i].x = re();
a[i].y = re();
a[i].po = lca(a[i].x, a[i].y);
a[i].t = dis[a[i].x] + dis[a[i].y] - (dis[a[i].po] << 1);
}
sort(a + 1, a + m + 1, comp);
r = a[1].t;
l = r - fin(a[1].x, a[1].y, a[1].po);
while (l < r)
{
mid = (l + r) >> 1;
if (judge(mid))
r = mid;
else
l = mid + 1;
}
printf("%d", r);
return 0;
}