题意:给你一颗树,每个节点有有一个权值,每次询问从x到y的最短路上权值在c到d之间的所有的点的权值和是多少。
思路:肯定要用树剖,因为询问c到d之间这种操作树上倍增很难做,但是用其它数据结构可以比较好的查询。我们可以用线段树来进行这种操作。每次询问一个区间时,如果当前区间被查询区间完全覆盖,并且区间里的最大指小于等于d,最小值大于等于c,才返回,否则继续查询。这种做法其实可以被卡掉,比如很长的路径上点权都是1, 2, 1, 2这种,而询问的c和d都是1,这样线段树上的询问会被卡成O(n)的。我感觉比较可行的做法是离散化之后用主席树,这样可以保证O(logn),但是既然没卡这个,就懒得写这种做法了。
线段树做法:
#include <bits/stdc++.h> #define LL long long #define ls(x) (x << 1) #define rs(x) ((x << 1) | 1) using namespace std; const int maxn = 100010; int head[maxn], Next[maxn * 2], ver[maxn * 2]; int son[maxn], d[maxn], f[maxn], top[maxn], sz[maxn], dfn[maxn]; int n, m, tot, cnt; int a[maxn], b[maxn]; void add(int x, int y) { ver[++tot] = y; Next[tot] = head[x]; head[x] = tot; } struct SegmentTree { int mx, mi; LL sum; }; SegmentTree tr[maxn * 4]; void pushup(int o) { tr[o].sum = tr[ls(o)].sum + tr[rs(o)].sum; tr[o].mi = min(tr[ls(o)].mi, tr[rs(o)].mi); tr[o].mx = max(tr[ls(o)].mx, tr[rs(o)].mx); } void build(int o, int l, int r) { if(l == r) { tr[o].mx = tr[o].mi = tr[o].sum = a[l]; return; } int mid = (l + r) >> 1; build(ls(o), l, mid); build(rs(o), mid + 1, r); pushup(o); } bool match(int o, int l, int r) { return tr[o].mi >= l && tr[o].mx <= r; } LL query(int o, int l, int r, int ql, int qr, int lb, int rb) { if(l >= ql && r <= qr && match(o, lb, rb)) { return tr[o].sum; } int mid = (l + r) >> 1; LL ans = 0; if(ql <= mid && !(tr[ls(o)].mx < lb || tr[ls(o)].mi > rb)) ans += query(ls(o), l, mid, ql, qr, lb, rb); if(qr > mid && !(tr[rs(o)].mx < lb || tr[rs(o)].mi > rb)) ans += query(rs(o), mid + 1, r, ql, qr, lb, rb); return ans; } void dfs1(int x, int fa) { sz[x] = 1; f[x] = fa; d[x] = d[fa] + 1; for (int i = head[x]; i; i = Next[i]) { int y = ver[i]; if(y == fa) continue; dfs1(y, x); sz[x] += sz[y]; if(!son[x] || sz[y] > sz[son[x]]) son[x] = y; } } void dfs2(int x, int fa, int t) { top[x] = t; dfn[x] = ++cnt; a[dfn[x]] = b[x]; if(son[x]) dfs2(son[x], x, t); for (int i = head[x]; i; i = Next[i]) { int y = ver[i]; if(y == fa || y == son[x]) continue; dfs2(y, x, y); } } LL solve(int l, int r, int x, int y) { LL ans = 0; while(top[l] != top[r]) { if(d[top[l]] > d[top[r]]) { ans += query(1, 1, n, dfn[top[l]], dfn[l], x, y); l = f[top[l]]; } else { ans += query(1, 1, n, dfn[top[r]], dfn[r], x, y); r = f[top[r]]; } } if(d[l] < d[r]) ans += query(1, 1, n, dfn[l], dfn[r], x, y); else ans += query(1, 1, n, dfn[r], dfn[l], x, y); return ans; } int main() { int x, y, l, r; while(~scanf("%d%d", &n, &m)) { memset(head, 0, sizeof(head)); memset(son, 0, sizeof(son)); memset(f, 0, sizeof(f)); memset(sz, 0, sizeof(sz)); memset(top, 0, sizeof(top)); memset(d, 0, sizeof(d)); tot = 0; cnt = 0; for (int i = 1; i <= n; i++) scanf("%d", &b[i]); for (int i = 1; i <= n - 1; i++) { scanf("%d%d", &x, &y); add(x, y); add(y, x); } dfs1(1, 0); dfs2(1, 0, 1); build(1, 1, n); while(m--) { scanf("%d%d%d%d", &l, &r, &x, &y); cout << solve(l, r, x, y); if(m == 0) cout << endl; else cout << " "; } } }