给定三棵带边权的树 $T_1,T_2,T_3$,大小都为 $n$,要求 $\max\limits_{i,j}\{\mathrm{dist}_1(i,j)+\mathrm{dist}_2(i,j)+\mathrm{dist}_3(i,j)\}$。$n\leq 10^5$。
考虑两棵树怎么做。有一种做法是给 $T_2$ 的每个点都新挂一个点,边权为 $T_1$ 中这个点到根的距离。然后枚举 $T_1$ 中的 $\mathrm{LCA}$ 统计路径,维护子树内任意两点在(挂点后的)$T_2$ 中距离的最大值和这两个点的编号(其实就是虚树的直径)。合并两棵子树时统计答案,最长的路径肯定是从一棵子树的直径的一端连向另一棵子树的直径的一端;合并直径时也类似,新直径要么是原来的某条直径,要么是一条直径的一端连向另一条直径的一端(这个结论要求边权非负)。用 $\mathrm{RMQ}$ 求 $\mathrm{LCA}$ 可以 $\mathcal{O}(1)$ 算距离,所以除去 $\mathcal{O}(n\log n)$ 的 ST 表预处理,复杂度是 $\mathcal{O}(n)$。
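三棵树呢?先把第一棵树转成每个点度数不超过 $3$ 的树(代码中的 rebuild:多出来的儿子用边权为 $0$ 的虚点串起来),再对它进行边分治,这样每次分出的两部分大小都至少按常数比例缩小;然后对当前分治连通块中的点在第二棵树上建虚树,按照上面的做法做即可(第一棵树的距离由两点到分治边端点的距离加上分治边权得到,第二棵树的深度挂到第三棵树上)。至于 $\mathcal{O}(n)$ 建虚树,可以参考 https://www.cnblogs.com/akura/p/14423740.html 这题的方法:一开始把点按第二棵树的 dfs 序排好,分治时按所在的一侧稳定地划分,子问题中的点仍然有序,不必重新排序。这样复杂度就是 $\mathcal{O}(n\log n)$ 了。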
代码巨难写。
#include<bits/stdc++.h>
// Competitive-programming shorthands used throughout this file.
// `rg` used to expand to `register`, which was only an optimisation hint and is
// ill-formed since C++17 (GCC warns, Clang rejects it by default). It now
// expands to nothing, so every `rg int x` is simply `int x`.
#define rg
#define il inline
#define cn const
#define gc getchar()
// Inclusive ascending / descending index loops; the bound is evaluated once.
#define fp(i, a, b) for(int i = (a), ed = (b); i <= ed; ++i)
#define fb(i, a, b) for(int i = (a), ed = (b); i >= ed; --i)
// Walk the forward-star adjacency list of u (edge index in i, -1 terminates).
#define go(u) for(int i = head[u]; ~i; i = e[i].nxt)
#define pb push_back
#define mp make_pair
#define fi first
#define se second
using namespace std;
typedef cn int cint;
typedef long long LL;
// (vertex, weight) adjacency entry.
typedef pair<int, LL> pr;
// (depth, vertex) key used by the LCA sparse tables.
typedef pair<int, int> pr2;
// Reads one decimal integer (a '-' anywhere before the first digit makes it
// negative) from stdin into x, skipping any other characters before it.
// If input runs out before a digit appears, x is left as 0 instead of the
// skip loop spinning forever on EOF.
inline void rd(int &x){
    x = 0;
    int f = 1;
    // Keep getchar()'s result in an int: EOF does not fit in a char, and an
    // unsigned plain char would turn EOF into 255 so the skip loop never ended.
    int c = getchar();
    while(c != EOF && (c < '0' || '9' < c)){
        if(c == '-')f = -1;
        c = getchar();
    }
    while('0' <= c && c <= '9')x = (x<<1)+(x<<3)+(c^48), c = getchar();
    x *= f;
}
// Array bounds: maxn covers the n <= 1e5 original vertices; maxm is about twice
// that, because rebuild() adds fewer than n virtual vertices to T1 and main()
// hangs an extra leaf i+n under every vertex i of T3.
cint maxn = 100010, maxm = 200010;
int n;
// Adjacency lists (neighbour, weight) of the three input trees; t3 also holds
// the extra leaves i+n, hence its larger bound.
vector<pr> t1[maxn], t2[maxn], t3[maxm];
// Best total distance found so far (printed at the end).
LL ans;
// Vertex count of the rebuilt first tree: starts at n and grows in rebuild().
int m;
// Edge store of the rebuilt first tree in forward-star form: head[u] is the
// first edge out of u (-1 = none, set by memset in main), nxt chains the rest,
// k is the next free slot. add2 stores edges in consecutive pairs, which is
// why i^1 is used as the reverse of edge i.
struct edge{
int to, nxt;
LL dis;
}e[maxm<<1];
int head[maxm], k;
// Appends the directed edge u -> v with weight w to u's forward-star list.
il void add(cint &u, cint &v, cn LL &w){
    e[k].to = v;
    e[k].nxt = head[u];
    e[k].dis = w;
    head[u] = k;
    ++k;
}
// Stores an undirected edge as two consecutive directed edges, so (as long as
// edges are only added through here) the reverse of edge i is i^1.
il void add2(cint &u, cint &v, cn LL &w){
    add(u, v, w);
    add(v, u, w);
}
// Weighted depths in T2 and T3 (dis3 also covers the attached leaves i+n).
LL dis2[maxn], dis3[maxm];
// lg[x] = floor(log2(x)) for x >= 1, used to pick the sparse-table level.
int lg[maxm<<1];
// Euler tours of T2 / T3: unweighted depth, tour sequence, first and last
// position of each vertex in the tour, and the tour length.
int dep2[maxn], elr2[maxn<<1], fir2[maxn], fnl2[maxn], len2;
int dep3[maxm], elr3[maxm<<1], fir3[maxm], fnl3[maxm], len3;
// Sparse tables of (depth, vertex) over the Euler tours. Index 0 covers
// positions [i, i+2^j-1] and index 1 covers [i-2^j+1, i] (see init2/init3).
pr2 mn2[2][maxn<<1][18], mn3[2][maxm<<1][19];
// Edge-division / virtual-tree working state:
//   siz, dfn, lst : subtree size, preorder number and last preorder number in
//                   the subtree (vertices of the rebuilt T1);
//   vrt           : real vertices sorted by T2 Euler order;
//   tmp           : scratch buffers for partitioning them by side;
//   stk, tp       : stack used to build the virtual tree;
//   rt            : root of the current virtual tree (0 = none);
//   typ           : side (0/1) of the current cut edge a real vertex lies on;
//   f[u][s][0/1]  : endpoints of the best pair of side s in u's virtual
//                   subtree (0 = no vertex).
int siz[maxm], dfn[maxm], lst[maxm], vrt[maxn], tmp[2][maxn], stk[maxn], tp, rt, typ[maxn], f[maxn][2][2];
// mark[i]: edge i of the rebuilt T1 has already been cut by edge division.
bool mark[maxm<<1];
// dis: T1 distance to the endpoint of the current cut edge on the vertex's
// side; val: that distance copied for real vertices; g[u][s]: score of the
// pair stored in f[u][s].
LL dis[maxm], val[maxn], g[maxn][2];
// Children lists of the current virtual tree (cleared by dfs2).
vector<int> to[maxn];
// Reads the n-1 weighted edges "u v w" of one input tree and records each edge
// in the adjacency lists of both endpoints.
il void read(vector<pr> *t){
    for(int remain = n - 1; remain > 0; --remain){
        int a, b;
        LL c;
        rd(a), rd(b), scanf("%lld", &c);
        t[a].pb(mp(b, c));
        t[b].pb(mp(a, c));
    }
}
// Rebuilds the subtree of T1 under u (parent pre) into the edge store so that
// every vertex has at most three incident edges. The first child hangs
// directly off u. Each further child gets a fresh virtual vertex (++m), which
// is chained after the previous attachment point with a 0-weight edge, so
// distances between original vertices are unchanged.
void rebuild(int u, int pre){
    int tail = 0;   // current attachment point; 0 until the first child is seen
    for(auto &ch : t1[u]){
        if(ch.fi == pre) continue;
        if(tail){
            ++m;
            add2(tail, m, 0);
            add2(m, ch.fi, ch.se);
            tail = m;
        }else{
            add2(u, ch.fi, ch.se);
            tail = u;
        }
        rebuild(ch.fi, u);
    }
}
// Euler-tour DFS of T2 from u (parent pre). For each vertex it records the
// depth (dep2), the weighted depth (dis2) and the first and last positions
// (fir2/fnl2) in the tour elr2. The tour lists u on entry and again after
// each child returns.
void getdis2(int u, int pre){
    ++len2;
    elr2[len2] = u;
    fir2[u] = len2;
    dep2[u] = dep2[pre] + 1;
    for(auto &nb : t2[u]){
        if(nb.fi == pre) continue;
        dis2[nb.fi] = dis2[u] + nb.se;
        getdis2(nb.fi, u);
        elr2[++len2] = u;
    }
    fnl2[u] = len2;
}
// Euler-tour DFS of T3 (including the attached leaves i+n) from u (parent pre).
// It fills dep3, dis3, fir3, fnl3 and the tour elr3 exactly as getdis2 does
// for T2.
void getdis3(int u, int pre){
    ++len3;
    elr3[len3] = u;
    fir3[u] = len3;
    dep3[u] = dep3[pre] + 1;
    for(auto &nb : t3[u]){
        if(nb.fi == pre) continue;
        dis3[nb.fi] = dis3[u] + nb.se;
        getdis3(nb.fi, u);
        elr3[++len3] = u;
    }
    fnl3[u] = len3;
}
// Builds two min-sparse-tables of (depth, vertex) over the T2 Euler tour:
// mn2[0][i][j] covers positions [i, i+2^j-1] (extending right),
// mn2[1][i][j] covers positions [i-2^j+1, i] (extending left).
il void init2(){
    for(int i = 1; i <= len2; ++i){
        cn pr2 leaf = mp(dep2[elr2[i]], elr2[i]);
        mn2[0][i][0] = leaf;
        mn2[1][i][0] = leaf;
    }
    for(int j = 1; j <= 17; ++j){
        cint half = 1 << (j - 1);
        for(int i = 1; i + 2*half - 1 <= len2; ++i)
            mn2[0][i][j] = min(mn2[0][i][j-1], mn2[0][i+half][j-1]);
        for(int i = len2; i >= 2*half; --i)
            mn2[1][i][j] = min(mn2[1][i][j-1], mn2[1][i-half][j-1]);
    }
}
// Same as init2 but over the T3 Euler tour, with one more level because
// that tour is about twice as long.
il void init3(){
    for(int i = 1; i <= len3; ++i){
        cn pr2 leaf = mp(dep3[elr3[i]], elr3[i]);
        mn3[0][i][0] = leaf;
        mn3[1][i][0] = leaf;
    }
    for(int j = 1; j <= 18; ++j){
        cint half = 1 << (j - 1);
        for(int i = 1; i + 2*half - 1 <= len3; ++i)
            mn3[0][i][j] = min(mn3[0][i][j-1], mn3[0][i+half][j-1]);
        for(int i = len3; i >= 2*half; --i)
            mn3[1][i][j] = min(mn3[1][i][j-1], mn3[1][i-half][j-1]);
    }
}
// O(1) LCA in T2: the shallowest vertex on the Euler tour between the first
// occurrences of u and v. It is found by overlapping a right-extending window
// starting at the left position with a left-extending window ending at the
// right position, both of length 2^k.
il int getlca2(int u, int v){
    int l = fir2[u], r = fir2[v];
    if(l > r) swap(l, r);
    cint k = lg[r - l + 1];
    return min(mn2[0][l][k], mn2[1][r][k]).se;
}
// O(1) LCA in T3, using the same two overlapping sparse-table windows as getlca2.
il int getlca3(int u, int v){
    int l = fir3[u], r = fir3[v];
    if(l > r) swap(l, r);
    cint k = lg[r - l + 1];
    return min(mn3[0][l][k], mn3[1][r][k]).se;
}
// Orders T2 vertices by their first position in the T2 Euler tour.
il bool cmp(cint &u, cint &v){
    cint pu = fir2[u], pv = fir2[v];
    return pu < pv;
}
void getrt(int u, int pre, cint &n, int &edg, int &mn){
siz[u] = 1;
go(u)if(i != pre && !mark[i])getrt(e[i].to, i^1, n, edg, mn), siz[u] += siz[e[i].to];
if(max(siz[u], n-siz[u]) < mn)mn = max(siz[u], n-siz[u]), edg = pre;
}
void dfs(int u, int pre, int &tot){
dfn[u] = ++tot, siz[u] = 1;
go(u)if(e[i].to != pre && !mark[i]){
dis[e[i].to] = dis[u]+e[i].dis;
dfs(e[i].to, u, tot);
siz[u] += siz[e[i].to];
}
lst[u] = tot;
}
// Adds the virtual-tree edge u -> v (u is the parent) and clears the per-side
// pair state (f, g) of both endpoints, so values from an earlier division step
// are never reused. rt tracks the shallowest parent seen so far, which is the
// virtual-tree root.
il void link(cint &u, cint &v){
    if(!rt || dep2[u] < dep2[rt]) rt = u;
    to[u].pb(v);
    for(int s = 0; s < 2; ++s){
        g[u][s] = g[v][s] = 0;
        for(int t = 0; t < 2; ++t) f[u][s][t] = f[v][s][t] = 0;
    }
}
// Pushes vertex u onto the virtual-tree stack. Vertices must arrive in
// increasing fir2 order. Branches that can no longer grow are closed with
// link(), and the LCA with the previous stack top is inserted when needed.
il void ins(cint &u){
    if(!tp){
        stk[tp = 1] = u;
        return;
    }
    cint lca = getlca2(u, stk[tp]);
    if(lca != stk[tp]){
        while(tp > 1 && fir2[stk[tp-1]] >= fir2[lca]){
            link(stk[tp-1], stk[tp]);
            --tp;
        }
        if(lca != stk[tp]){
            link(lca, stk[tp]);
            stk[tp] = lca;
        }
    }
    stk[++tp] = u;
}
// Builds the T2 virtual tree of nd[1..n] (already sorted by fir2), then links
// the root-to-leaf chain still left on the stack.
il void build(int *nd, cint &n){
    for(int i = 1; i <= n; ++i) ins(nd[i]);
    for(int i = 2; i <= tp; ++i) link(stk[i-1], stk[i]);
}
// Distance in T3 between the leaves u+n and v+n. main() hangs leaf x+n under
// T3 vertex x with weight dis2[x], so for u != v this equals
// dist3(u, v) + dis2[u] + dis2[v]. Either argument being 0 ("no vertex")
// yields 0.
il LL dist(cint &u, cint &v){
    if(u && v){
        cint a = u + n, b = v + n;
        return dis3[a] + dis3[b] - 2*dis3[getlca3(a, b)];
    }
    return 0;
}
// Best score over the four ways of pairing an endpoint of u's side-p pair with
// an endpoint of v's side-q pair, where score = dist + val of both ends.
// Returns 0 if either side is empty.
il LL calc(cint &u, cint &p, cint &v, cint &q){
    if(!f[u][p][0] || !f[v][q][0]) return 0;
    cint xs[2] = {f[u][p][0], f[u][p][1]};
    cint ys[2] = {f[v][q][0], f[v][q][1]};
    LL best = dist(xs[0], ys[0]) + val[xs[0]] + val[ys[0]];
    // Remaining combinations, in order (0,1), (1,0), (1,1).
    for(int c = 1; c < 4; ++c){
        cint x = xs[c >> 1], y = ys[c & 1];
        best = max(best, dist(x, y) + val[x] + val[y]);
    }
    return best;
}
// Folds the side-d pair stored at virtual-tree child v into u. f[u][d] keeps
// the two endpoints and g[u][d] the score of the best pair, where a pair (a, b)
// scores dist(a, b) + val[a] + val[b]. The new best pair is either one of the
// two old pairs or an endpoint of one joined with an endpoint of the other
// (the usual tree-diameter merge).
// NOTE(review): that argument assumes non-negative edge weights; confirm
// against the problem constraints.
il void merge(cint &u, cint &v, cint &d){
// v has nothing on side d: u stays as it is.
if(!f[v][d][0])return;
// u has nothing on side d yet: take v's pair wholesale.
if(!f[u][d][0])return f[u][d][0] = f[v][d][0], f[u][d][1] = f[v][d][1], g[u][d] = g[v][d], void();
// The four cross pairs (endpoint of u's pair, endpoint of v's pair). An
// endpoint may be 0 ("none"); dist() then yields 0, and val[0] is never
// assigned in the visible code, so such a pair scores just one val.
rg LL res1 = dist(f[u][d][0], f[v][d][0])+val[f[u][d][0]]+val[f[v][d][0]];
rg LL res2 = dist(f[u][d][0], f[v][d][1])+val[f[u][d][0]]+val[f[v][d][1]];
rg LL res3 = dist(f[u][d][1], f[v][d][0])+val[f[u][d][1]]+val[f[v][d][0]];
rg LL res4 = dist(f[u][d][1], f[v][d][1])+val[f[u][d][1]]+val[f[v][d][1]];
rg LL mx = max(max(res1, res2), max(res3, res4));
mx = max(mx, max(g[u][d], g[v][d])), g[u][d] = mx;
// Rewrite u's endpoints to match whichever candidate reached mx. The checks
// are ordered, so on ties the earlier candidate wins. If none match, mx is
// u's old score and u's old pair is kept.
if(mx == res1)f[u][d][1] = f[v][d][0];
else if(mx == res2)f[u][d][1] = f[v][d][1];
else if(mx == res3)f[u][d][0] = f[v][d][0];
else if(mx == res4)f[u][d][0] = f[v][d][1];
else if(mx == g[v][d])f[u][d][0] = f[v][d][0], f[u][d][1] = f[v][d][1];
}
void dfs2(int u, cn LL &len){
for(auto &x : to[u]){
dfs2(x, len);
ans = max(ans, calc(u, 0, x, 1)-2*dis2[u]+len);
ans = max(ans, calc(u, 1, x, 0)-2*dis2[u]+len);
merge(u, x, 0), merge(u, x, 1);
}
to[u].clear();
}
// One step of edge-centroid division on the rebuilt first tree.
//   nw : any vertex of the current component;  m : its number of vertices;
//   n  : number of real (original) vertices in it;
//   nd[1..n] : those vertices, kept sorted by T2 Euler order (fir2).
// Cuts the most balanced edge, scores every pair of real vertices lying on
// opposite sides via the T2 virtual tree, then stably partitions nd by side
// (so both halves stay sorted for the next build) and recurses on each side.
void divi(int nw, int m, int n, int *nd){
if(m == 1)return;
int edg, rt1, rt2, mn = 0x3f3f3f3f, tot = 0, d1 = 0, d2 = 0;
// edg is the edge from the chosen vertex to its DFS parent, so rt1 is the
// parent-side endpoint and rt2 the child-side endpoint; cut it in both directions.
getrt(nw, -1, m, edg, mn), rt1 = e[edg].to, rt2 = e[edg^1].to, mark[edg] = mark[edg^1] = 1;
// T1 distances from each side's endpoint. The two dfs calls share tot, so the
// sides get consecutive, disjoint dfn ranges.
dis[rt1] = dis[rt2] = 0, dfs(rt1, 0, tot), dfs(rt2, 0, tot), build(nd, n);
// typ = 1 for vertices inside rt2's dfn range. Each real vertex starts as a
// one-vertex set on its side.
fp(i, 1, n){
typ[nd[i]] = dfn[rt2] <= dfn[nd[i]] && dfn[nd[i]] <= lst[rt2];
f[nd[i]][typ[nd[i]]][0] = nd[i], val[nd[i]] = dis[nd[i]];
}
// Cross-side pairs: their T1 distance is val[a] + val[b] + the cut edge weight.
dfs2(rt, e[edg].dis);
// Reset the virtual-tree stack and root before the next build.
tp = rt = 0;
// Stable partition: side-0 vertices first, then side-1, each in original order.
fp(i, 1, n){
if(!typ[nd[i]])tmp[0][++d1] = nd[i];
else tmp[1][++d2] = nd[i];
}
fp(i, 1, d1)nd[i] = tmp[0][i];
fp(i, 1, d2)nd[i+d1] = tmp[1][i];
// siz[rt2] is read after the first recursion returns; that recursion stays
// inside rt1's component, so rt2's size is still intact.
divi(rt1, siz[rt1], d1, nd), divi(rt2, siz[rt2], d2, nd+d1);
}
int main(){
// freopen("in", "r", stdin);
// freopen("o1", "w", stdout);
rd(n), read(t1), read(t2), read(t3);
m = n, memset(head, -1, sizeof head), rebuild(1, 0);
getdis2(1, 0);
fp(i, 1, n)t3[i].pb(mp(i+n, dis2[i]));
getdis3(1, 0);
fp(i, 2, max(len2, len3))lg[i] = lg[i>>1]+1;
init2(), init3();
fp(i, 1, n)vrt[i] = i;
sort(vrt+1, vrt+1+n, cmp), divi(1, m, n, vrt);
printf("%lld
", ans);
return 0;
}
多棵树怎么做?对第一棵树边分治,在第二棵树建出虚树,再在虚树上边分治,再在第三棵树上建虚树……最后就用上面的做法做即可(真的会有人敢出这种题吗?)。假设有 $k$ 棵树,复杂度应该是 $\mathcal{O}(n\log^{k-2}n)$。