题目链接
做法
\[
\begin{aligned}
& dep(x) + dep(y) - dep(LCA(x, y)) - dep'(LCA'(x, y)) \\
= {} & \frac{1}{2} \left( dep(x) + dep(y) - 2dep(LCA(x, y)) + dep(x) + dep(y) - 2dep'(LCA'(x, y)) \right) \\
= {} & \frac{1}{2} \left( dis(x, y) + dep(x) + dep(y) - 2dep'(LCA'(x, y)) \right)
\end{aligned}
\]
考虑对第一棵树边分治。设当前分治的中心边为 $ (U, V) $ ,在 $ U $ 侧选择节点 $ X $ ,在 $ V $ 侧选择节点 $ Y $ ,则 $ dis(X, Y) = dis(V, X) + dis(V, Y) $ 。令 $ X $ 为一类节点,贡献为 $ e1(X) = dep(X) + dis(V, X) $ ;令 $ Y $ 为二类节点,贡献为 $ e2(Y) = dep(Y) + dis(V, Y) $ 。枚举第二棵树的 $ LCA $ ,设 $ X, Y $ 在第二棵树中的 $ LCA $ 为 $ lca $ ,则对答案的贡献为 $ \frac{1}{2}(e1(X) + e2(Y) - 2dep'(lca)) $ 。为了保证时间复杂度正确,需要对第二棵树建虚树进行 $ DP $ 。
不优秀的实现会导致时间复杂度为 $ O(n \log^2 n) $ 。要做到 $ O(n \log n) $ ,每次建虚树的时间复杂度必须为 $ O(k) $ ( $ k $ 为关键点个数),所以需要用 $ RMQ $ 实现 $ O(1) $ 的 $ LCA $ ;另外每次建虚树需要将点按 $ dfs $ 序排序,可以将排序放在分治之前,然后在分治时把序列按原顺序划分成两段,每次建虚树直接使用划分好的数组(或者先分治下去,再归并排序,然后建虚树)。更改后时间复杂度为 $ O(n \log n) $ 。
注意 $ x $ 可以与 $ y $ 相同,而边分治未考虑这一点,所以还要考虑 $ x = y $ 的情况。
#include <bits/stdc++.h>
#define rep(i, a, b) for(int i = (a); i <= (b); i++)
#define per(i, a, b) for(int i = (a); i >= (b); i--)
#define pb push_back
#define mp make_pair
#define fst first
#define snd second
using namespace std;
// 64-bit integer used for all weights, depths and the answer.
typedef long long ll;
// (neighbour vertex, edge weight) entry of an input-tree adjacency list.
typedef pair<int, ll> pil;
// Magnitude of the "minus infinity" sentinel: ans, f1 and f2 start at -INF.
const ll INF = 1e17;
// Capacity of the fixed-size arrays below.
const int N = 800010;
// n: vertex count read in main.  m: starts at n and is incremented by build()
//    for every auxiliary vertex it creates, so vertices with id > n are auxiliary.
// ans: running maximum of the objective.
// w[v]: sum of first-tree edge weights on the path from vertex 1 to v (filled by rebuild()).
// fw[v]: weight of the first-tree edge between v and its parent; used as the
//        arc weight when v is attached in build()/rebuild().
int n, m; ll ans = -INF, w[N], fw[N];
// e1 / e2: adjacency lists of the first / second input tree.
// E: child lists of the virtual tree assembled by Solve() and cleared again by Dfs().
// ar[1..len]: scratch list of the children of the vertex currently handled by rebuild().
vector<pil> e1[N], e2[N]; vector<int> E[N]; int ar[N], len;
// Second-tree data filled by getdfn() and main:
// dep[v]: depth counted in edges (the start vertex has dep 1).
// dfn[v]: first position of v in the Euler tour; idx: current tour length; a[i]: vertex at tour position i.
// st[j][i]: vertex with the smallest dep among tour positions i .. i + 2^j - 1.
// lg[i]: floor(log2(i)), with lg[0] = -1.
// Dep[v]: sum of second-tree edge weights from vertex 1 to v.
int dep[N], dfn[N], idx, a[N], st[20][N], lg[N]; ll Dep[N];
// Forward-star storage of the restructured first tree.  cnt starts at 1 so the
// first arc gets index 2 and the two directions of one edge are arcs i and i ^ 1.
// to / nxt / val: head vertex, next arc in the list and weight of each arc;
// hed[v]: first arc leaving v (0 = none); used[i]: arc i has been cut by solve().
int cnt = 1, to[N + N], nxt[N + N], hed[N]; ll val[N + N]; bool used[N + N];
// State of the edge-centre search in getrt():
// size: reference count the score |size - 2 * sz[v]| is measured against.
// rte: arc with the best score so far (-1 when it is the start vertex itself); mn: that best score.
// sz[v]: number of vertices in v's subtree for the current traversal.
// tot is not referenced anywhere in the visible code.
// NOTE(review): with `using namespace std;` the unqualified name `size` is
// ambiguous with std::size when compiled as C++17; refer to it as ::size.
int size, rte, mn, sz[N], tot;
// value[v]: contribution of original vertex v assigned by Find().
// f1[v] / f2[v]: best value of a flag-1 / flag-2 vertex inside v's virtual subtree (Dfs()).
// flag[v]: side marker (1 or 2) set by Find() and reset to 0 by Dfs().
// sta[1..top]: stack used by Solve() while building the virtual tree.
ll value[N], f1[N], f2[N]; int flag[N], sta[N], top;
// Reads one decimal integer from stdin into x.
//
// Every non-digit character before the number is skipped; the number is
// negated when the character immediately preceding its first digit is '-'.
// The character that terminates the digit run is consumed as well.
// If end of input is reached before any digit, x is left at 0.
template<typename T> void gi(T &x) {
    x = 0;
    // getchar() returns int; keeping it as int lets EOF be told apart from
    // every real character (a plain char made the skip loop spin forever).
    int c = getchar(), pre = 0;
    while (c != EOF && (c < '0' || c > '9')) {
        pre = c;
        c = getchar();
    }
    // Accumulate in long long arithmetic, then narrow to T as before.
    for (; c >= '0' && c <= '9'; c = getchar()) x = static_cast<T>(x * 10ll + (c - '0'));
    if (pre == '-') x = -x;
}
// Adds the undirected edge x - y with weight z to the forward-star lists.
// The two directions are stored as consecutive arcs (x -> y first, then
// y -> x), each pushed to the front of its tail vertex's list.
inline void addedge(int x, int y, ll z) {
    // Arc x -> y.
    ++cnt;
    to[cnt] = y;
    val[cnt] = z;
    nxt[cnt] = hed[x];
    hed[x] = cnt;

    // Arc y -> x.
    ++cnt;
    to[cnt] = x;
    val[cnt] = z;
    nxt[cnt] = hed[y];
    hed[y] = cnt;
}
// Strict ordering of vertices by their dfn[] value (first Euler-tour position).
inline bool cmp(const int &x, const int &y) {
    const int posX = dfn[x];
    const int posY = dfn[y];
    return posX < posY;
}
// Depth-first walk over the second tree (adjacency e2) from u, whose parent is ff.
// Records the Euler tour in a[], the first tour position of each vertex in
// dfn[], the edge-count depth in dep[] and the weighted depth in Dep[].
void getdfn(int u, int ff) {
    ++idx;
    dfn[u] = idx;
    a[idx] = u;
    dep[u] = dep[ff] + 1;
    for (const auto &edge : e2[u]) {
        const int child = edge.first;
        if (child == ff) continue;
        Dep[child] = Dep[u] + edge.second;
        getdfn(child, u);
        // Re-append u after returning from each child so that the tour segment
        // between two vertices always passes through their lowest common ancestor.
        a[++idx] = u;
    }
}
inline int LCA(int x, int y) {
x = dfn[x], y = dfn[y]; if(x > y) swap(x, y); int k = y - x + 1;
return dep[st[lg[k]][x]] <= dep[st[lg[k]][y - (1 << lg[k]) + 1]] ? st[lg[k]][x] : st[lg[k]][y - (1 << lg[k]) + 1];
}
// Builds a binary tree over ar[l..r] and returns its root.
// An empty range yields 0 and a single element yields that element itself;
// otherwise a fresh auxiliary vertex (id ++m) is created and linked to the
// roots of the two halves, each link weighted by fw[] of the attached root.
int build(int l, int r) {
    if (l > r) return 0;
    if (l == r) return ar[l];
    const int mid = (l + r) >> 1;
    const int node = ++m;  // allocate the id before descending, as ids are handed out in pre-order
    const int lhs = build(l, mid);
    const int rhs = build(mid + 1, r);
    if (lhs) addedge(node, lhs, fw[lhs]);
    if (rhs) addedge(node, rhs, fw[rhs]);
    return node;
}
void rebuild(int u, int ff) {
len = 0; for(auto v : e1[u]) if(v.fst != ff) ar[++len] = v.fst, fw[v.fst] = v.snd;
int mid = (1 + len) >> 1, ls = build(1, mid), rs = build(mid + 1, len);
if(ls) addedge(u, ls, fw[ls]); if(rs) addedge(u, rs, fw[rs]);
for(auto v : e1[u]) if(v.fst != ff) w[v.fst] = w[u] + v.snd, rebuild(v.fst, u);
}
void getrt(int u, int ff, int ed) {
sz[u] = 1;
for(int i = hed[u]; i; i = nxt[i]) if(to[i] != ff && !used[i])
getrt(to[i], u, i), sz[u] += sz[to[i]];
if(abs(size - 2 * sz[u]) < mn) mn = abs(size - 2 * sz[u]), rte = ed;
}
// Walks the component containing u (not crossing cut arcs or returning to ff)
// and labels every original vertex (id <= n) it meets.
//
// opt - side marker stored in flag[]
// d   - accumulated arc weight carried to u; value[u] becomes d + w[u]
void Find(int u, int ff, int opt, ll d) {
    if (u <= n) {
        flag[u] = opt;
        value[u] = d + w[u];
    }
    for (int i = hed[u]; i != 0; i = nxt[i]) {
        const int v = to[i];
        if (v == ff || used[i]) continue;
        Find(v, u, opt, d + val[i]);
    }
}
void Dfs(int u) {
f1[u] = f2[u] = -INF;
if(flag[u] == 1) f1[u] = max(f1[u], value[u]);
if(flag[u] == 2) f2[u] = max(f2[u], value[u]);
for(auto v : E[u]) {
Dfs(v);
ans = max(ans, max(f1[u] + f2[v], f2[u] + f1[v]) - Dep[u] - Dep[u]);
f1[u] = max(f1[u], f1[v]), f2[u] = max(f2[u], f2[v]);
}
E[u].clear(), flag[u] = 0;
}
void Solve(vector<int> p) {
sta[top = 1] = 1;
for(auto v : p) {
if(v == 1) continue; int lca = LCA(v, sta[top]);
if(lca == sta[top]) { sta[++top] = v; continue; }
for(; top > 1 && dep[sta[top - 1]] >= dep[lca]; --top) E[sta[top - 1]].pb(sta[top]);
if(lca != sta[top]) E[lca].pb(sta[top]), sta[top] = lca;
sta[++top] = v;
}
for(; top > 1; --top) E[sta[top - 1]].pb(sta[top]); Dfs(1);
}
void solve(int u, vector<int> p) {
if(u == -1 || used[u]) return ;
used[u] = used[u ^ 1] = 1; int rt1 = to[u], rt2 = to[u ^ 1], t1, t2;
Find(rt1, rt2, 1, val[u]), Find(rt2, rt1, 2, 0);
vector<int> ls, rs; for(auto v : p) flag[v] == 1 ? ls.pb(v) : rs.pb(v);
Solve(p);
size = ls.size(), mn = m + 1, getrt(rt1, rt2, -1), solve(rte, ls);
size = rs.size(), mn = m + 1, getrt(rt2, rt1, -1), solve(rte, rs);
}
int main() {
gi(n), m = n;
rep(i, 2, n) { int x, y; ll z; gi(x), gi(y), gi(z), e1[x].pb(mp(y, z)), e1[y].pb(mp(x, z)); }
rep(i, 2, n) { int x, y; ll z; gi(x), gi(y), gi(z), e2[x].pb(mp(y, z)), e2[y].pb(mp(x, z)); }
getdfn(1, 0); rep(i, 1, idx) st[0][i] = a[i];
lg[0] = -1; rep(i, 1, idx) lg[i] = lg[i >> 1] + 1;
for(int j = 1; (1 << j) <= idx; j++)
for(int i = 1; i <= idx - (1 << j) + 1; i++)
st[j][i] = dep[st[j - 1][i]] <= dep[st[j - 1][i + (1 << j - 1)]] ? st[j - 1][i] : st[j - 1][i + (1 << j - 1)];
rebuild(1, 0);
vector<int> p; rep(i, 1, n) p.pb(i); sort(p.begin(), p.end(), cmp);
size = m, mn = m + 1, getrt(1, 0, -1), solve(rte, p), ans /= 2;
rep(i, 1, n) ans = max(ans, w[i] - Dep[i]);
printf("%lld
", ans);
return 0;
}