【树结构】树链剖分
当我们需要在一棵树上完成某些区间操作,而且要求复杂度严格保持在
所谓树链剖分,就是把树分割成链,把每条链放到线段树或其他数据结构里面维护。显然,只要我们保证每次区间操作涉及的链的个数为
对于一个区间查询如“对a, b最短路径上的所有节点权值求和”,只需要用倍增处理出c = LCA(a, b)转化为对于一个节点和他祖先节点的区间求和。之后只需要不断检查c是否在当前区间内。如果在直接调用数据结构求和,如果不在则求这一区间的总和,并将节点向上推,直到depth[a] < depth[c]。
以zjoi的树的统计一题为例给出代码:
#include <bits/stdc++.h>
using namespace std;
const int maxm = 100005, maxn = 100005;
int zkw[maxn * 4], sum[maxn * 4], N = 131072, t = 0;
void update(int i, int k)
{
i += N - 1; zkw[i] = sum[i] = k;
for (i >>= 1; i; i >>= 1) {
zkw[i] = max(zkw[i << 1], zkw[(i << 1) + 1]);
sum[i] = sum[i << 1] + sum[(i << 1) + 1];
}
}
pair<int, int> query(int i, int j)
{
int ans = 0, maxi = INT_MIN, p, q;
for (p = i + N - 1, q = j + N - 1; p < q; p >>= 1, q >>= 1) {
if (p & 1) { ans += sum[p], maxi = max(maxi, zkw[p]); p++; }
if (!(q & 1)) { ans += sum[q], maxi = max(maxi, zkw[q]); q--; }
}
if (p == q) ans += sum[p], maxi = max(maxi, zkw[p]);
return make_pair(ans, maxi);
}
struct node {
int to, next;
node() { to = next = 0; }
} edge[2 * maxm];
int head[maxn], top = 0;
int dat[maxn], siz[maxn], id[maxn], ind[maxn], hev[maxn], dep[maxn];
int father[maxn][35];
int n, m;
void push(int i, int j) { edge[++top].to = j; edge[top].next = head[i]; head[i] = top; }
int dfs1(int i)
{
siz[i] = 1;
for (int k = head[i]; k; k = edge[k].next) {
if (!dep[edge[k].to]) {
dep[edge[k].to] = dep[i] + 1; father[edge[k].to][0] = i;
siz[i] += dfs1(edge[k].to);
}
}
return siz[i];
}
void dfs2(int i, int from)
{
ind[i] = from; update(++t, dat[i]); id[i] = t;
if (!head[i]) return;
hev[i] = 0;
for (int k = head[i]; k; k = edge[k].next) {
if (dep[edge[k].to] > dep[i] && siz[edge[k].to] > siz[hev[i]])
hev[i] = edge[k].to;
}
if (!hev[i]) {return;}
dfs2(hev[i], from);
for (int k = head[i]; k; k = edge[k].next)
if (dep[edge[k].to] > dep[i] && edge[k].to != hev[i])
dfs2(edge[k].to, edge[k].to);
}
void travel(int, int);
void init()
{
dep[1] = 1;
memset(father, 0, sizeof father);
dfs1(1);
dfs2(1, 1);
for (int j = 1; j <= 20; j++)
for (int i = 1; i <= n; i++)
father[i][j] = father[father[i][j-1]][j-1];
}
inline int lowbit(int i) { return i&(-i); }
int lca(int a, int b)
{
if (dep[a] < dep[b]) swap(a, b);
int dd = dep[a] - dep[b];
while (dd) { a = father[a][(int)(log2(lowbit(dd)))]; dd -= lowbit(dd); }
if (a == b) return a;
for (int i = 20; i >= 0; i--)
if (father[a][i] != father[b][i])
a = father[a][i], b = father[b][i];
return father[a][0];
}
int query_sum(int i, int j) // j is anc of i
{
if (dep[i] < dep[j]) return 0;
if (dep[ind[i]] <= dep[j])
return query(id[j], id[i]).first;
return query_sum(father[ind[i]][0], j) + query(id[ind[i]], id[i]).first;
}
int query_max(int i, int j)
{
if (dep[i] < dep[j]) return INT_MIN;
if (dep[ind[i]] <= dep[j])
return query(id[j], id[i]).second;
return max(query_max(father[ind[i]][0], j), query(id[ind[i]], id[i]).second);
}
inline void change(int i, int j) { update(id[i], j); }
inline int read() { int a; scanf("%d", &a); return a; }
int main()
{
memset(dep, 0, sizeof dep);
memset(head, 0, sizeof head);
memset(hev, 0, sizeof hev);
memset(sum, 0, sizeof sum);
memset(zkw, -127/3, sizeof zkw);
n = read();
for (int i = 1; i < n; i++) {
int a, b; a = read(); b = read();
push(a, b);
push(b, a);
}
for (int i = 1; i <= n; i++)
dat[i] = read();
init();
m = read();
char str[10]; int a, b, c;
for (int i = 1; i <= m; i++) {
scanf("%s", str);
a = read(); b = read();
if (strcmp(str, "CHANGE") == 0) change(a, b);
else if (strcmp(str, "QSUM") == 0) {
c = lca(a, b);
printf("%d
", query_sum(a, c)+query_sum(b, c)-query_sum(c, c));
}
else {
c = lca(a, b);
printf("%d
", max(query_max(a, c), query_max(b, c)));
}
}
return 0;
}