题意简述
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。
支持3个操作:
I. CHANGE u t: 把节点u的权值改为t
II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值
III. QSUM u v: 询问从点u到点v的路径上的节点的权值和
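题解思路
树链剖分（重链剖分）+ 线段树。第一次 DFS 求出每个节点的父亲、深度、子树大小和重儿子；第二次 DFS 优先走重儿子，记录每个节点所在重链的链顶以及 DFS 序，使每条重链在 DFS 序上是一段连续区间。再用一棵 zkw 线段树按 DFS 序同时维护区间和与区间最大值：修改操作是单点修改；路径询问时不断把链顶较深的一端跳到其链顶的父亲，从而把路径拆成 O(log n) 段连续区间，每段做一次区间查询。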
代码
#include <iostream>
#include <algorithm>
#define int long long
// Capacity bound: nodes are numbered 1..n and every per-node array below has
// Maxn slots (assumes n < Maxn — TODO confirm against the input limits).
constexpr int Maxn = 30001;
// Seed for max queries; assumes every node weight is greater than -INF.
constexpr int INF = 0x3f3f3f3f;

// Input scratch: node count, endpoint / argument pair, and query count.
int n, u, v;
int q;
// Decoded query kind: 0 = CHANGE, 1 = QMAX, otherwise QSUM.
int opt;
// Edge-id counter for add_edge and DFS-order counter for dfs2.
int cnt1, cnt2;
// Query keyword ("CHANGE", "QMAX" or "QSUM").
char st[10];

// Adjacency list: h[x] is the newest edge id out of x, nxt chains to the
// previous one, to holds the edge target. Each tree edge is stored twice.
int h[Maxn];
int to[Maxn << 1];
int nxt[Maxn << 1];

// Heavy-light decomposition data, filled by dfs1 / dfs2:
int fa[Maxn];   // parent
int dep[Maxn];  // depth (root = 0)
int sz[Maxn];   // subtree size
int hvs[Maxn];  // heavy child (0 for a leaf)
int vv[Maxn];   // node weight as read from input
int top[Maxn];  // head of the heavy chain containing the node

// va[p] is the weight of the node whose DFS position is p; id[x] is the DFS
// position of node x.
int va[Maxn];
int id[Maxn];
/// Bottom-up ("zkw") segment tree over positions 1..n that stores, for every
/// node, both the sum and the maximum of its range.
///
/// Layout: the leaf for position x lives at index x + N, where N is the
/// smallest power of two with N >= n + 2. The extra room guarantees that the
/// sentinel leaves for positions 0 and n + 1 exist, which the open-interval
/// query walk below relies on. Capacity: requires 2 * N <= 4 * Maxn.
struct zkw_Segment_Tree
{
    int N;
    int mx[Maxn << 2], sum[Maxn << 2];

    /// Builds the tree from the global array va[1..n].
    void build(int n)
    {
        for (N = 1; N < n + 2; N <<= 1) {}
        for (int i = 1; i <= n; ++i) mx[i + N] = sum[i + N] = va[i];
        // Internal nodes N-1 .. 1, children before parents.
        for (int i = N - 1; i > 0; --i)
        {
            sum[i] = sum[i << 1] + sum[i << 1 | 1];
            mx[i] = std::max(mx[i << 1], mx[i << 1 | 1]);
        }
    }

    /// Sets position x to value k and refreshes all of its ancestors.
    void change(int x, int k)
    {
        int i = x + N;
        sum[i] = mx[i] = k;
        for (i >>= 1; i > 0; i >>= 1)
        {
            sum[i] = sum[i << 1] + sum[i << 1 | 1];
            mx[i] = std::max(mx[i << 1], mx[i << 1 | 1]);
        }
    }

    /// Returns s plus the sum over positions [l, r] (1 <= l <= r <= n).
    int querysm(int l, int r, int s = 0)
    {
        // i / j start just outside [l, r]; climb until they are siblings.
        for (int i = l + N - 1, j = r + N + 1; (i ^ j) != 1; i >>= 1, j >>= 1)
        {
            if (!(i & 1)) s += sum[i ^ 1];  // i is a left child: take its right sibling
            if (j & 1) s += sum[j ^ 1];     // j is a right child: take its left sibling
        }
        return s;
    }

    /// Returns max(s, maximum over positions [l, r]) (1 <= l <= r <= n).
    int querymx(int l, int r, int s = -INF)
    {
        for (int i = l + N - 1, j = r + N + 1; (i ^ j) != 1; i >>= 1, j >>= 1)
        {
            if (!(i & 1)) s = std::max(s, mx[i ^ 1]);
            if (j & 1) s = std::max(s, mx[j ^ 1]);
        }
        return s;
    }
}sgt;
// Appends the directed edge u -> v to u's adjacency list. Edge ids start at 1,
// so a 0 in h / nxt marks the end of a list.
void add_edge(const int& u, const int& v)
{
    const int e = ++cnt1;
    to[e] = v;
    nxt[e] = h[u];  // link to the previously newest edge out of u
    h[u] = e;
}
/// First pass of heavy-light decomposition, run from the root.
/// For every node below u it sets fa (parent), dep (depth) and, after the
/// recursion, sz (subtree size) and hvs (the child with the largest subtree).
/// Expects fa[u] and dep[u] to be set by the caller for the root.
/// Note: recursion depth equals the tree height (up to n).
void dfs1(int u)
{
    sz[u] = 1;
    for (int e = h[u]; e; e = nxt[e])
    {
        const int w = to[e];
        if (w == fa[u]) continue;  // skip the edge back to the parent
        fa[w] = u;
        dep[w] = dep[u] + 1;
        dfs1(w);
        sz[u] += sz[w];
        // hvs[u] starts at 0 and node 0 is never visited, so sz[0] stays 0.
        if (sz[w] > sz[hvs[u]]) hvs[u] = w;
    }
}
/// Second pass of heavy-light decomposition.
/// Assigns DFS positions (id), copies each node's weight into va at that
/// position, and records top, the head of the heavy chain containing the node.
/// The heavy child is visited first, so every heavy chain occupies a
/// contiguous range of positions with the head at the smallest index.
void dfs2(int u, int tp)
{
    top[u] = tp;
    id[u] = ++cnt2;
    va[cnt2] = vv[u];
    if (hvs[u]) dfs2(hvs[u], tp);  // heavy child continues the current chain
    for (int e = h[u]; e; e = nxt[e])
    {
        const int w = to[e];
        if (w != fa[u] && w != hvs[u]) dfs2(w, w);  // light child starts a new chain
    }
}
// Sets node u's weight to t by updating the segment-tree slot at u's DFS position.
void change(const int& u, const int& t)
{
    const int pos = id[u];
    sgt.change(pos, t);
}
// Sum of node weights on the tree path u..v, both endpoints included, added to s.
int querysm(int u, int v, int s = 0)
{
    while (top[u] != top[v])
    {
        // Always lift the endpoint whose chain head is deeper.
        if (dep[top[u]] < dep[top[v]]) std::swap(u, v);
        const int head = top[u];
        s += sgt.querysm(id[head], id[u]);
        u = fa[head];
    }
    // Both endpoints share a chain; the shallower one has the smaller position.
    const bool uDeeper = dep[u] > dep[v];
    const int lo = uDeeper ? v : u;
    const int hi = uDeeper ? u : v;
    s += sgt.querysm(id[lo], id[hi]);
    return s;
}
// Maximum node weight on the tree path u..v, both endpoints included, combined with s.
// The default s = -INF acts as "no value yet".
int querymx(int u, int v, int s = -INF)
{
    while (top[u] != top[v])
    {
        // Always lift the endpoint whose chain head is deeper.
        if (dep[top[u]] < dep[top[v]]) std::swap(u, v);
        const int head = top[u];
        s = std::max(s, sgt.querymx(id[head], id[u]));
        u = fa[head];
    }
    // Both endpoints share a chain; the shallower one has the smaller position.
    const bool uDeeper = dep[u] > dep[v];
    const int lo = uDeeper ? v : u;
    const int hi = uDeeper ? u : v;
    s = std::max(s, sgt.querymx(id[lo], id[hi]));
    return s;
}
signed main()
{
std::ios::sync_with_stdio(0);
std::cin >> n;
for (register int i = 1; i < n; ++i)
{
std::cin >> u >> v;
add_edge(u, v); add_edge(v, u);
}
for (register int i = 1; i <= n; ++i) std::cin >> vv[i];
fa[1] = 1; dfs1(1); dfs2(1, 1); sgt.build(n);
std::cin >> q;
for (register int i = 1; i <= q; ++i)
{
std::cin >> st >> u >> v;
opt = st[1] == 'H' ? 0 : st[1] == 'M' ? 1 : 2;
if (opt == 0) change(u, v);
else if (opt == 1) printf("%lld
", querymx(u, v));
else printf("%lld
", querysm(u, v));
}
}