题面
给出两棵树\(A,B\),求\(\max_{x,y}\left(dep_A(x) + dep_A(y) - dep_A(\mathrm{lca}_A(x,y)) - dep_B(\mathrm{lca}_B(x,y))\right)\)
解法
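考虑改写式子:由于 \(dep_A(x) + dep_A(y) - dep_A(\mathrm{lca}_A(x,y)) = \frac{1}{2}\left(dis_A(x,y) + dep_A(x) + dep_A(y)\right)\),答案为\(\frac{1}{2}\max_{x,y}\left(dis_A(x,y) + dep_A(x) + dep_A(y) - 2\,dep_B(\mathrm{lca}_B(x,y))\right)\)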
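考虑点分治:把当前分治中心 \(mid\) 所在连通块内的点拿出来,在 \(B\) 上建出虚树,关键点 \(x\) 的点权赋值为 \(val(x) = dis_A(mid,x) + dep_A(x)\),\(mid\) 的同一棵子树内的点染同一种颜色(\(mid\) 自成一色),在虚树上 \(dp\) 求不同颜色点对的 \(\max\left(val(x) + val(y) - 2\,dep_B(\mathrm{lca}_B(x,y))\right)\) 即可。\(x=y\) 的情况(贡献为 \(dep_A(x) - dep_B(x)\))单独计算。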
点击查看代码
//晦暗的宇宙,我们找不到光,看不见尽头,但我们永远都不会被黑色打倒。——Quinn葵因
#include <bits/stdc++.h>
// Shorthand for the 64-bit type used for edge weights, depths and answers.
#define ll long long
// Capacity of every per-vertex array; vertices are indexed 1..n, so n must stay below N.
#define N 1000000
using std::vector;
using std::pair;
// Number of vertices (both trees have the same vertex set 1..n).
int n;
// Adjacency entry: (neighbour, edge weight).
#define pil pair<int,ll>
#define mp std::make_pair
vector<pil>A[N], B[N]; // adjacency lists of tree A and tree B
// depa[x] / depb[x]: weighted distance from vertex 1 to x in tree A / tree B.
ll depa[N], depb[N];
// dep[x]: unweighted depth of x in tree B (vertex 1 has depth 1, sentinel 0 has depth 0); used by lca().
int dep[N];
// Fills depa[x] = depa[u] + (weighted distance u -> x in tree A) for every vertex x
// in the part of tree A hanging below u when entered from fa. Called as dfsa(1, 0),
// relying on depa[1] == 0 from static initialisation.
// Uses an explicit stack instead of recursion: the recursion depth would equal the
// tree height (up to ~N = 1e6 on a path), which overflows a default-sized call stack.
inline void dfsa(int u, int fa) {
    std::vector<std::pair<int, int>> stk; // (vertex, parent) still to expand
    stk.emplace_back(u, fa);
    while (!stk.empty()) {
        const int cur = stk.back().first;
        const int par = stk.back().second;
        stk.pop_back();
        for (const auto &edge : A[cur]) {
            const int x = edge.first;
            if (x == par)
                continue;
            // depa[cur] is final here, so the child's depth can be set before it is expanded.
            depa[x] = depa[cur] + edge.second;
            stk.emplace_back(x, cur);
        }
    }
}
// F[x][i]: the 2^i-th ancestor of x in tree B (0 once past the root); binary-lifting table for lca().
int F[N][20];
// dfn[x]: DFS preorder index of x in tree B; virtual-tree vertices are sorted by it.
int dfn[N];
// Running preorder counter that assigns dfn[].
int cnt;
// Preorder traversal of tree B below u (entered from fa). For every visited vertex x it sets:
//   F[x][0..19] - binary-lifting ancestors (F[x][0] = parent, 0 above the root),
//   dep[x]      - unweighted depth (dep[fa] + 1),
//   dfn[x]      - preorder index (from the global counter cnt),
//   depb[x]     - weighted depth in tree B (depb[u] must already be set; depb[1] == 0).
// Uses an explicit stack so that very deep trees cannot overflow the call stack.
// Children are pushed in reverse adjacency order, so they are popped in adjacency order
// and dfn[] numbering is exactly what the recursive traversal would produce.
inline void dfsb(int u, int fa) {
    std::vector<std::pair<int, int>> stk; // (vertex, parent) still to visit
    stk.emplace_back(u, fa);
    while (!stk.empty()) {
        const int cur = stk.back().first;
        const int par = stk.back().second;
        stk.pop_back();
        F[cur][0] = par;
        dep[cur] = dep[par] + 1;
        // All ancestors were visited earlier in preorder, so their tables are complete.
        for (int i = 1; i <= 19; ++i)
            F[cur][i] = F[F[cur][i - 1]][i - 1];
        dfn[cur] = ++cnt;
        for (auto it = B[cur].rbegin(); it != B[cur].rend(); ++it) {
            const int x = it->first;
            if (x == par)
                continue;
            depb[x] = depb[cur] + it->second;
            stk.emplace_back(x, cur);
        }
    }
}
// Centroid search state: sum = size of the component being searched, siz[x] = subtree
// size of x from the last find(), root = best centroid candidate found so far.
int sum, siz[N], root;
// maxn[x]: size of the largest piece left if x were removed from its component
// (maxn[0] is set to a value larger than any real maxn in main, acting as +infinity).
int maxn[N];
// val[x] = dis_A(mid, x) + depa[x] for every x in the component of the current centroid mid.
ll val[N];
// c[x]: colour of x = which child subtree of the current centroid contains x (the centroid's colour is itself).
int c[N];
// vis[x]: x has already been chosen as a centroid, i.e. removed from tree A for later searches.
int vis[N];
// Centroid search in tree A. Rooting the component of u at u (fa is the excluded parent,
// vertices with vis set are excluded, u itself is always included), computes
//   siz[x]  - subtree size, and
//   maxn[x] - largest piece left if x were removed (max child subtree vs. sum - siz[x]),
// and moves the global root to x whenever maxn[x] < maxn[root].
// Iterative version: a stack-based preorder is recorded, and walking it backwards yields
// exactly the post-order of the recursive version (children in adjacency order), so the
// root updates happen in the same sequence. This avoids call-stack overflow on deep trees.
inline void find(int u, int fa) {
    // Reused buffers; find() is no longer recursive, so sharing them is safe.
    static std::vector<std::pair<int, int>> stk, order; // (vertex, parent)
    stk.clear();
    order.clear();
    stk.emplace_back(u, fa);
    while (!stk.empty()) {
        const int w = stk.back().first;
        const int p = stk.back().second;
        stk.pop_back();
        siz[w] = 1;
        maxn[w] = 1;
        order.emplace_back(w, p);
        for (const auto &edge : A[w]) {
            const int v = edge.first;
            if (v == p || vis[v])
                continue;
            stk.emplace_back(v, w);
        }
    }
    for (std::size_t k = order.size(); k-- > 0;) {
        const int w = order[k].first;
        const int p = order[k].second;
        // All children of w come before w in this order, so siz[w] and maxn[w] are complete.
        maxn[w] = std::max(maxn[w], sum - siz[w]);
        if (maxn[w] < maxn[root])
            root = w;
        if (k > 0) { // the start vertex has no in-component parent to report to
            siz[p] += siz[w];
            maxn[p] = std::max(maxn[p], siz[w]);
        }
    }
}
// Vertices of the current centroid component; build() extends it into the virtual-tree vertex set.
vector<int>P;
// Labels the part of the current component hanging below u (entered from fa; vertices with
// vis set are excluded). The caller must have set val[u] = dis_A(centroid, u) and the colour
// of fa (called as dis(v, v) with c[v] preset). Every visited vertex x ends with
//   c[x]   = c[fa] (the whole subtree inherits one colour),
//   val[x] = dis_A(centroid, x) + depa[x],
// and is appended to P in the same post-order as the recursive version produced.
// Iterative to avoid call-stack overflow on deep trees: distances are propagated in a
// preorder pass first, and depa[] is added afterwards so children read the pure distance.
inline void dis(int u, int fa) {
    // Reused buffers; dis() is not recursive, so sharing them is safe.
    static std::vector<std::pair<int, int>> stk; // (vertex, parent)
    static std::vector<int> order;               // preorder of visited vertices
    stk.clear();
    order.clear();
    stk.emplace_back(u, fa);
    while (!stk.empty()) {
        const int w = stk.back().first;
        const int p = stk.back().second;
        stk.pop_back();
        c[w] = c[p]; // p was popped earlier, so its colour is already final
        order.push_back(w);
        for (const auto &edge : A[w]) {
            const int v = edge.first;
            if (vis[v] || v == p)
                continue;
            val[v] = val[w] + edge.second;
            stk.emplace_back(v, w);
        }
    }
    // Reverse preorder == recursive post-order, so P receives vertices in the original order.
    for (std::size_t k = order.size(); k-- > 0;) {
        const int w = order[k];
        val[w] += depa[w];
        P.push_back(w);
    }
}
inline int lca(int x, int y) {
// std::cout<<"LCA "<<x<<" "<<y<<"\n";
if (dep[x] < dep[y])
std::swap(x, y);
for (int i = 19; i >= 0; --i) {
if (dep[F[x][i]] >= dep[y])
x = F[x][i];
// std::cout<<x<<"\n";
}
if (x == y)
return x;
for (int i = 19; i >= 0; --i) {
if (F[x][i] != F[y][i])
x = F[x][i], y = F[y][i];
// std::cout<<"UP "<<x<<" "<<y<<"\n";
}
return F[x][0];
}
// Strict ordering of vertices by their tree-B preorder index dfn[].
inline bool cmp(int x, int y) {
    const int lhs = dfn[x];
    const int rhs = dfn[y];
    return lhs < rhs;
}
// f[x][0], f[x][1]: ids (stored as ll) of the best-val key vertex in x's virtual subtree and of
// the best-val one whose colour differs from f[x][0]'s; 0 marks an empty slot.
ll f[N][2];
// key[x]: x is a vertex of the current component (rather than an LCA added by build()).
bool key[N];
// Fi[x]: parent of x in the current virtual tree.
int Fi[N];
// Best dis_A(x,y) + depa[x] + depa[y] - 2*depb[lca_B(x,y)] seen so far; main() halves it.
ll ans = -1e18;
#define pii pair<ll,int>
// (-val, vertex) pairs: scratch storage for ranking candidate vertices.
vector<pii>M;
// Merges virtual-tree child y into its virtual-tree parent x (y -> x).
// Each f[v] keeps at most two candidate vertices of v's virtual subtree:
//   f[v][0] - largest val,
//   f[v][1] - largest val among those whose colour differs from f[v][0]'s
// (slot value 0 is the empty sentinel: build() sets val[0] = -1e18 and c[0] = -1).
// x is the tree-B LCA of any vertex already in x's set and any vertex in y's subtree,
// so every differently-coloured cross pair is scored with depb[x] before the sets merge.
inline void merge(int x, int y) {
    for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 2; ++j) {
            const auto a = f[x][i];
            const auto b = f[y][j];
            if (c[a] != c[b])
                ans = std::max(ans, val[a] + val[b] - 2 * depb[x]);
        }
    // Rank the four candidates by val descending (ties: smaller vertex id first).
    std::pair<long long, int> cand[4] = {
        std::make_pair(-val[f[x][0]], static_cast<int>(f[x][0])),
        std::make_pair(-val[f[x][1]], static_cast<int>(f[x][1])),
        std::make_pair(-val[f[y][0]], static_cast<int>(f[y][0])),
        std::make_pair(-val[f[y][1]], static_cast<int>(f[y][1])),
    };
    std::sort(cand, cand + 4);
    f[x][0] = cand[0].second;
    // Runner-up: best remaining candidate with a different colour. One always exists
    // unless every slot is the empty sentinel, in which case f[x][1] is already 0.
    for (int i = 1; i < 4; ++i)
        if (c[cand[i].second] != c[f[x][0]]) {
            f[x][1] = cand[i].second;
            return;
        }
}
// Reads the next integer from stdin, skipping any non-digit characters before it.
// A '-' immediately preceding the digits makes the result negative (consistent with lread).
// Returns 0 if end of input is reached before any digit.
// ch is an int so EOF stays distinguishable from every real character; the original
// stored it in a char and spun forever on exhausted input.
inline int iread(){
    int ch = getchar();
    bool neg = false;
    while (ch != EOF && (ch < '0' || ch > '9')) {
        neg = (ch == '-');
        ch = getchar();
    }
    int s = 0;
    while (ch >= '0' && ch <= '9') {
        s = s * 10 + (ch - '0');
        ch = getchar();
    }
    return neg ? -s : s;
}
// Reads the next 64-bit integer from stdin, skipping any non-digit characters before it.
// A '-' immediately preceding the digits makes the result negative (edge weights may be
// negative). Returns 0 if end of input is reached before any digit.
// ch is an int so EOF stays distinguishable from every real character; the original
// stored it in a char and spun forever on exhausted input.
inline long long lread(){
    int ch = getchar();
    bool neg = false;
    while (ch != EOF && (ch < '0' || ch > '9')) {
        neg = (ch == '-');
        ch = getchar();
    }
    long long s = 0;
    while (ch >= '0' && ch <= '9') {
        s = s * 10 + (ch - '0');
        ch = getchar();
    }
    return neg ? -s : s;
}
// Builds the virtual tree (in tree B) over the vertices currently in P and runs the
// two-best-colours DP on it; merge() updates the global ans along the way.
// On entry P holds the current centroid component (unordered). On exit P holds the
// virtual-tree vertex set (component vertices plus the LCAs of dfn-adjacent pairs),
// sorted by dfn, and key[] has been reset for all of them.
inline void build() {
    // Vertex 0 is the empty slot in f[][]: it never wins a max and matches no real colour.
    val[0] = -1e18;
    c[0] = -1;
    std::sort(P.begin(), P.end(), cmp);
    for (const int v : P)
        key[v] = 1;
    // Adding the LCA of every dfn-adjacent pair makes the set closed under LCA.
    const int k = static_cast<int>(P.size());
    P.reserve(2 * P.size());
    for (int i = 1; i < k; ++i)
        P.push_back(lca(P[i], P[i - 1]));
    std::sort(P.begin(), P.end(), cmp);
    P.erase(std::unique(P.begin(), P.end()), P.end());
    // In a dfn-sorted, LCA-closed set the virtual parent of P[i] is lca(P[i-1], P[i]).
    for (std::size_t i = 1; i < P.size(); ++i)
        Fi[P[i]] = lca(P[i], P[i - 1]);
    // Only component vertices are candidates; added LCAs start with two empty slots.
    for (const int v : P) {
        f[v][0] = key[v] ? v : 0;
        f[v][1] = 0;
    }
    // Reverse dfn order processes every vertex before its virtual parent.
    for (std::size_t i = P.size(); i-- > 1;)
        merge(Fi[P[i]], P[i]);
    for (const int v : P)
        key[v] = 0;
}
// Centroid-decomposition step for the centroid u of its current component in tree A.
// Removes u (vis), labels every vertex x of the component with
// val[x] = dis_A(u, x) + depa[x] and a colour naming the child subtree of u containing x
// (u gets its own colour), collects them in P, runs the tree-B virtual-tree DP via build(),
// then recurses into the centroid of each remaining sub-component.
inline void solve(int u) {
    vis[u] = 1;
    find(u, 0); // refreshes siz[] of u's neighbours, read by the recursion loop below
    P.clear();
    val[u] = depa[u];
    c[u] = u;
    P.push_back(u);
    for (const auto &edge : A[u]) {
        const int nb = edge.first;
        if (vis[nb])
            continue;
        val[nb] = edge.second;
        c[nb] = nb;
        dis(nb, nb);
    }
    build();
    for (const auto &edge : A[u]) {
        const int nb = edge.first;
        sum = siz[nb];
        root = 0;
        if (vis[nb])
            continue;
        find(nb, 0);
        solve(root);
    }
}
signed main() {
// freopen("q.in","r",stdin);
// freopen("q.out","w",stdout);
// scanf("%d", &n);
n = iread();
for (int i = 1; i < n; ++i) {
int x = iread(), y = iread();
ll v = lread();
A[x].push_back(mp(y, v));
A[y].push_back(mp(x, v));
}
for (int i = 1; i < n; ++i) {
int x = iread(), y = iread();
ll v = lread();
B[x].push_back(mp(y, v));
B[y].push_back(mp(x, v));
}
dfsa(1, 0);
dfsb(1, 0);
maxn[0] = n * 2;
root = 0;
sum = n;
find(1, 0);
solve(root);
ans = ans / 2;
for (int i = 1; i <= n; ++i)
ans = std::max(ans, 2ll * depa[i] - depa[i] - depb[i]);
std::cout << ans << "\n";
}
/* Sample input: n, then the n-1 edges "x y w" of tree A, then the n-1 edges of tree B.
6
1 2 2
1 3 0
2 4 1
2 5 -7
3 6 0
1 2 -1
2 3 -1
2 5 3
2 6 -2
3 4 8
*/