[hihoCoder#1381]Little Y's Tree
试题描述
小Y有一棵n个节点的树,每条边都有正的边权。
小J有q个询问,每次小J会删掉这个树中的k条边,这棵树被分成k+1个连通块。小J想知道每个连通块中最远点对距离的和。
这里的询问是互相独立的,即每次都是在小Y的原树上进行操作。
输入
第一行一个整数n,接下来n-1行每行三个整数u,v,w,其中第i行表示第i条边边权为wi,连接了ui,vi两点。
接下来一行一个整数q,表示有q组询问。
对于每组询问,第一行一个正整数k,接下来一行k个不同的1到n-1之间的整数,表示删除的边的编号。
1<=n,q,Σk<=105, 1<=w<=109
输出
共q行,每行一个整数表示询问的答案。
输入示例
5 1 2 2 2 3 3 2 4 4 4 5 2 3 4 1 2 3 4 1 4 2 2 3
输出示例
0 7 4
数据规模及约定
见“输入”
题解
我们将每次删除的所有边的深度较大的节点作为关键点,建立虚树。然后我们发现我们可以维护一下区间的连通性,将所有节点按照 dfs 序从小到大排序以后,用线段树合并连通信息。对于两个连通块 A 和 B,若 A 的直径为 (a, b),B 的直径为 (c, d),那么 A U B 的直径就是 (a, b), (c, d), (a, c), (a, d), (b, c), (b, d) 六种情况,我们取一个最大值即可。对于一颗虚树,我们按照深度从大到小依次查询该节点对应区间的连通块直径,累计答案,再将这个区间打上删除标记,最后记得要恢复删除标记。
对了,这题要 O(1) 求 LCA。
#include <iostream> #include <cstdio> #include <cstring> #include <cstdlib> #include <cctype> #include <algorithm> using namespace std; int read() { int x = 0, f = 1; char c = getchar(); while(!isdigit(c)){ if(c == '-') f = -1; c = getchar(); } while(isdigit(c)){ x = x * 10 + c - '0'; c = getchar(); } return x * f; } #define maxn 100010 #define maxm 200010 #define maxlog 20 #define oo 2147483647 #define LL long long int n, m, head[maxn], Next[maxm], To[maxm], dist[maxm]; struct Edge { int a, b, c; Edge() {} Edge(int _1, int _2, int _3): a(_1), b(_2), c(_3) {} } es[maxn]; void AddEdge(int a, int b, int c) { To[++m] = b; dist[m] = c; Next[m] = head[a]; head[a] = m; swap(a, b); To[++m] = b; dist[m] = c; Next[m] = head[a]; head[a] = m; return ; } int dep[maxn], list[maxm], cl, pos[maxn], ord[maxn], clo, ordr[maxn], Pos[maxn]; LL Dep[maxn]; void build(int u, int fa) { list[++cl] = u; pos[u] = cl; ord[u] = ++clo; Pos[clo] = u; for(int e = head[u]; e; e = Next[e]) if(To[e] != fa) { dep[To[e]] = dep[u] + 1; Dep[To[e]] = Dep[u] + dist[e]; build(To[e], u); list[++cl] = u; } ordr[u] = clo; return ; } int Lca[maxlog][maxm], Log[maxm]; void rmq_init() { Log[1] = 0; for(int i = 2; i <= cl; i++) Log[i] = Log[i>>1] + 1; for(int i = 1; i <= cl; i++) Lca[0][i] = list[i]; for(int j = 1; (1 << j) <= cl; j++) for(int i = 1; i + (1 << j) - 1 <= cl; i++) { int a = Lca[j-1][i], b = Lca[j-1][i+(1<<j-1)]; if(dep[a] < dep[b]) Lca[j][i] = a; else Lca[j][i] = b; } return ; } int lca(int a, int b) { int l = min(pos[a], pos[b]), r = max(pos[a], pos[b]); int t = Log[r-l+1]; a = Lca[t][l]; b = Lca[t][r-(1<<t)+1]; return dep[a] < dep[b] ? a : b; } LL calc(int a, int b) { return Dep[a] + Dep[b] - (Dep[lca(a,b)] << 1); } struct Node { bool hasv; int A, B; LL Len; Node() {} Node(int _1, int _2, LL _3, bool _4): A(_1), B(_2), Len(_3), hasv(_4) {} } ns[maxn<<2]; void maintain(Node& o, Node lc, Node rc) { o.Len = -1; if(!lc.hasv && !rc.hasv){ o.Len = o.hasv = 0; return ; } if(!lc.hasv) { o = rc; return ; } if(!rc.hasv) { o = lc; return ; } o.hasv = 1; if(o.Len < lc.Len) o.Len = lc.Len, o.A = lc.A, o.B = lc.B; if(o.Len < rc.Len) o.Len = rc.Len, o.A = rc.A, o.B = rc.B; LL d = calc(lc.A, rc.A); if(o.Len < d) o.Len = d, o.A = lc.A, o.B = rc.A; d = calc(lc.A, rc.B); if(o.Len < d) o.Len = d, o.A = lc.A, o.B = rc.B; d = calc(lc.B, rc.A); if(o.Len < d) o.Len = d, o.A = lc.B, o.B = rc.A; d = calc(lc.B, rc.B); if(o.Len < d) o.Len = d, o.A = lc.B, o.B = rc.B; return ; } void build(int L, int R, int o) { if(L == R) ns[o] = Node(Pos[L], Pos[R], 0, 1); else { int M = L + R >> 1, lc = o << 1, rc = lc | 1; build(L, M, lc); build(M+1, R, rc); maintain(ns[o], ns[lc], ns[rc]); ns[o].hasv = 1; } // printf("seg[%d, %d]: %d %d %lld ", L, R, ns[o].A, ns[o].B, ns[o].Len); return ; } void update(int L, int R, int o, int ql, int qr, bool v) { if(ql <= L && R <= qr) ns[o].hasv = v; else { int M = L + R >> 1, lc = o << 1, rc = lc | 1; if(ql <= M) update(L, M, lc, ql, qr, v); if(qr > M) update(M+1, R, rc, ql, qr, v); maintain(ns[o], ns[lc], ns[rc]); } return ; } Node query(int L, int R, int o, int ql, int qr) { if(ql <= L && R <= qr) return ns[o]; int M = L + R >> 1, lc = o << 1, rc = lc | 1; Node ans(-1, -1, -1, 0); if(ql <= M) { Node tmp = query(L, M, lc, ql, qr), tt(0, 0, -1, 1); maintain(tt, tmp, ans); if(tmp.hasv) ans = tt; } if(qr > M) { Node tmp = query(M+1, R, rc, ql, qr), tt(0, 0, -1, 1); maintain(tt, tmp, ans); if(tmp.hasv) ans = tt; } // printf("[%d, %d] %d %d %d %lld(%d) ans: %lld ", L, R, o, ql, qr, ns[o].Len, ns[o].hasv, ans.Len); return ans; } bool cmp(int a, int b) { return pos[a] < pos[b]; } int psi[maxn], cpi, ps[maxn], cp, vis[maxn]; bool flg[maxn]; struct Vtree { int m, head[maxn], Next[maxm], To[maxm]; LL dist[maxm], ans; void init() { ans = m = 0; return ; } void AddEdge(int a, int b, LL c) { // printf("Add2: %d %d %lld ", a, b, c); To[++m] = b; dist[m] = c; Next[m] = head[a]; head[a] = m; swap(a, b); To[++m] = b; dist[m] = c; Next[m] = head[a]; head[a] = m; return ; } void dfs(int u, int fa) { for(int e = head[u]; e; e = Next[e]) if(To[e] != fa) dfs(To[e], u); if(flg[u]) { Node tmp = query(1, n, 1, ord[u], ordr[u]); flg[u] = 0; if(tmp.hasv) ans += tmp.Len; // printf("at %d tmp: %lld [%d, %d] ", u, tmp.Len, ord[u], ordr[u]); update(1, n, 1, ord[u], ordr[u], 0); } return ; } void dfs2(int u, int fa) { for(int e = head[u]; e; e = Next[e]) if(To[e] != fa) dfs2(To[e], u); update(1, n, 1, ord[u], ordr[u], 1); head[u] = 0; return ; } } vt; int main() { n = read(); for(int i = 1; i < n; i++) { int a = read(), b = read(), c = read(); AddEdge(a, b, c); es[i] = Edge(a, b, c); } build(1, 0); rmq_init(); // for(int i = 1; i <= cl; i++) printf("%d%c", Lca[0][i], i < cl ? ' ' : ' '); // for(int i = 1; i <= n; i++) printf("%d%c", pos[i], i < n ? ' ' : ' '); build(1, n, 1); int q = read(); // for(int i = 1; i <= n; i++) printf("%d%c", ord[i], i < n ? ' ' : ' '); while(q--) { int cpi = read(); cp = 0; for(int i = 1; i <= cpi; i++) { int e = read(), u = dep[es[e].a] < dep[es[e].b] ? es[e].b : es[e].a; ps[++cp] = psi[i] = u; vis[ps[cp]] = q + 1; flg[ps[cp]] = 1; } if(vis[1] != q + 1) ps[++cp] = 1, flg[1] = 1; sort(psi + 1, psi + cpi + 1, cmp); for(int i = 1; i < cpi; i++) { int c = lca(psi[i], psi[i+1]); if(vis[c] != q + 1) vis[c] = q + 1, ps[++cp] = c; } sort(ps + 1, ps + cp + 1, cmp); vt.init(); for(int i = 1; i < cp; i++) { int a = ps[i], b = ps[i+1], c = lca(a, b); vt.AddEdge(b, c, Dep[b] - Dep[c]); } vt.dfs(1, 0); vt.dfs2(1, 0); printf("%lld ", vt.ans); } return 0; }