一个有点权和边权的二叉树,多次询问点权在 $[L,R]$ 的点到 $u$ 的距离和
$n,q \leq 100000$
sol:
1.点分治
建出分治树的结构,考虑计算距离的过程
我们知道 $dis(u,v) = dep_u + dep_v - 2 \times dep_{lca}$(这里 $dep$ 指带权深度,即到根的距离)
因为点分树的树高是 $O(\log n)$ 的,所以可以从 $u$ 出发在点分树上暴力往上爬,枚举 $u$ 与 $v$ 在点分树上的 lca(即第一个把它们分开的分治重心)
把点权差分一下:点权在 $[L,R]$ 内的答案 = 点权 $\leq R$ 的答案 $-$ 点权 $\leq L-1$ 的答案
对每个重心,把它所在分治连通块里的点按点权排序,开 $3$ 个前缀和 vector,分别表示点权前 $i$ 小的点到它的距离和、这些点到它在点分树上父亲的距离和、以及这些点的点数
因为点分树的树高是 $O(\log n)$ 的,每个点只会出现在 $O(\log n)$ 个重心的 vector 里,所以总空间是 $O(n\log n)$ 的
每次二分找到 $[L,R]$ 在当前 vector 里的位置,算一下距离就可以了
顺便吐槽为什么我每次点分治都懒得写 $O(1)$ lca
// Solution 1: centroid decomposition.
//
// dis(u, v) is computed on the original tree through a heavy-light LCA.
// For every centroid c, G[c] holds one entry per vertex v of c's component,
// sorted by weight, whose fields become prefix sums (over entries 0..i) of
// dis(v, c), dis(v, par[c]) and the vertex count. A query climbs the centroid
// tree from u and restricts those sums to weights in [L, R] with two binary
// searches per centroid.
#include <algorithm>
#include <cctype>
#include <cstddef>
#include <cstdio>
#include <vector>

using LL = long long;

// Reads one (optionally negative) decimal integer from stdin.
// `ch` is an int so that EOF is representable and isdigit never sees a
// negative char; hitting EOF before any digit yields 0 instead of looping.
inline int read() {
    int x = 0, f = 1;
    int ch = std::getchar();
    for (; ch != EOF && !std::isdigit(ch); ch = std::getchar())
        if (ch == '-') f = -f;
    for (; std::isdigit(ch); ch = std::getchar()) x = 10 * x + ch - '0';
    return x * f;
}

const int maxn = 200010;

int n, q, A, yr[maxn];  // yr[v]: weight of vertex v
// Linked adjacency lists: two directed edges are stored per tree edge.
int first[maxn], to[maxn << 1], nx[maxn << 1], val[maxn << 1], cnt;

// Appends the directed edge u -> v with length w.
inline void add(int u, int v, int w) {
    to[++cnt] = v;
    nx[cnt] = first[u];
    first[u] = cnt;
    val[cnt] = w;
}

LL dis[maxn];  // dis[v]: weighted distance from vertex 1 to v
int fa[maxn];  // fa[v]: parent of v with the tree rooted at vertex 1

// Heavy-light decomposition used only to answer LCA queries.
namespace LCA {
int dep[maxn], bl[maxn], sz[maxn];

// Fills fa, dep, dis and subtree sizes for the subtree rooted at x.
void dfs1(int x) {
    sz[x] = 1;
    for (int i = first[x]; i; i = nx[i]) {
        if (to[i] == fa[x]) continue;
        fa[to[i]] = x;
        dep[to[i]] = dep[x] + 1;
        dis[to[i]] = dis[x] + val[i];
        dfs1(to[i]);
        sz[x] += sz[to[i]];
    }
}

// Sets bl[x] to the top vertex of x's heavy chain (col).
void dfs2(int x, int col) {
    int k = 0;
    bl[x] = col;
    for (int i = first[x]; i; i = nx[i])
        if (dep[to[i]] > dep[x] && sz[to[i]] > sz[k]) k = to[i];
    if (!k) return;
    dfs2(k, col);
    for (int i = first[x]; i; i = nx[i])
        if (dep[to[i]] > dep[x] && to[i] != k) dfs2(to[i], to[i]);
}

// Lowest common ancestor of x and y in the tree rooted at 1.
int lca(int x, int y) {
    while (bl[x] != bl[y]) {
        if (dep[bl[x]] < dep[bl[y]]) std::swap(x, y);
        x = fa[bl[x]];
    }
    return dep[x] > dep[y] ? y : x;
}
}  // namespace LCA

// One vertex of a centroid's component. After build() the numeric fields
// hold prefix sums over the sorted entries 0..i.
struct Node {
    LL col, sum, sig, cnt;  // weight, dis to centroid, dis to parent centroid, count
    bool operator<(const Node &b) const { return col < b.col; }
};

std::vector<Node> G[maxn];

// Tree distance between x and y; an id of 0 ("no vertex") yields 0.
inline LL caldis(int x, int y) {
    if (!x || !y) return 0;
    return dis[x] + dis[y] - 2 * dis[LCA::lca(x, y)];
}

// Centroid-decomposition state: f = largest remaining piece if the vertex were
// removed, comp_size = subtree size inside the current component, vis = vertex
// already used as a centroid, par = parent in the centroid tree, sig = size of
// the component being searched, root = best centroid found so far.
int f[maxn], comp_size[maxn], vis[maxn], par[maxn], sig, root;

// Computes comp_size for the not-yet-removed part hanging below x (away from
// `from`) and returns it.
int calc_size(int x, int from) {
    comp_size[x] = 1;
    for (int i = first[x]; i; i = nx[i])
        if (to[i] != from && !vis[to[i]]) comp_size[x] += calc_size(to[i], x);
    return comp_size[x];
}

// Searches the component containing x (of total size sig) for its centroid,
// storing it in `root`.
void findroot(int x, int from) {
    f[x] = 0, comp_size[x] = 1;
    for (int i = first[x]; i; i = nx[i]) {
        int y = to[i];
        if (y == from || vis[y]) continue;
        findroot(y, x);
        comp_size[x] += comp_size[y];
        f[x] = std::max(f[x], comp_size[y]);
    }
    f[x] = std::max(f[x], sig - comp_size[x]);
    if (f[x] < f[root]) root = x;
}

// Appends every vertex of rt's component (reached from x) to G[rt].
void add_node(int x, int from, int rt) {
    G[rt].push_back(Node{yr[x], caldis(x, rt), par[rt] ? caldis(x, par[rt]) : 0, 1});
    for (int i = first[x]; i; i = nx[i]) {
        if (to[i] == from || vis[to[i]]) continue;
        add_node(to[i], x, rt);
    }
}

// Makes x a centroid, builds its prefix-sum table and recurses into the
// remaining components.
void build(int x) {
    vis[x] = 1;
    add_node(x, 0, x);
    // Sentinel with weight -1 so that every binary search finds an index >= 0.
    G[x].push_back(Node{-1, 0, 0, 0});
    std::sort(G[x].begin(), G[x].end());
    for (std::size_t i = 1; i < G[x].size(); ++i) {
        G[x][i].sum += G[x][i - 1].sum;
        G[x][i].sig += G[x][i - 1].sig;
        G[x][i].cnt += G[x][i - 1].cnt;
    }
    for (int i = first[x]; i; i = nx[i]) {
        int v = to[i];
        if (vis[v]) continue;
        // Recompute the size from v itself: sizes left over from the previous
        // findroot are wrong for the neighbour on the path back to its start.
        sig = calc_size(v, x);
        root = 0;
        findroot(v, x);
        par[root] = x;
        build(root);
    }
}

// Index of the last entry of g whose weight is <= c (>= 0 thanks to the sentinel).
int last_le(const std::vector<Node> &g, LL c) {
    auto it = std::upper_bound(g.begin(), g.end(), c,
                               [](LL key, const Node &nd) { return key < nd.col; });
    return static_cast<int>(it - g.begin()) - 1;
}

// Sum of dis(x, v) over all vertices v whose weight lies in [ql, qr].
LL query(int x, int ql, int qr) {
    LL ans = 0;
    for (int i = x; i; i = par[i]) {
        const std::vector<Node> &g = G[i];
        int ed = last_le(g, qr);
        int st = last_le(g, ql - 1);
        LL in_range = g[ed].cnt - g[st].cnt;
        ans += g[ed].sum - g[st].sum;
        if (i != x) ans += in_range * caldis(i, x);
        // Remove vertices already counted through the parent centroid.
        if (par[i]) ans -= (g[ed].sig - g[st].sig) + in_range * caldis(x, par[i]);
    }
    return ans;
}

int main() {
    n = read(), q = read(), A = read();
    for (int i = 1; i <= n; i++) yr[i] = read();
    for (int i = 2; i <= n; i++) {
        int u = read(), v = read(), w = read();
        add(u, v, w);
        add(v, u, w);
    }
    LCA::dep[1] = 1;
    LCA::dfs1(1);
    LCA::dfs2(1, 1);
    sig = n;
    f[0] = 0x7fffffff;  // root == 0 means "none yet"; any real vertex beats it
    root = 0;
    findroot(1, 0);
    build(root);
    LL lastans = 0;
    while (q--) {
        int x = read(), a = read(), b = read();
        int l = static_cast<int>(std::min((a + lastans) % A, (b + lastans) % A));
        int r = static_cast<int>(std::max((a + lastans) % A, (b + lastans) % A));
        lastans = query(x, l, r);
        std::printf("%lld\n", lastans);
    }
    return 0;
}
2.主席树
以点权为版本开主席树,还是考虑计算距离,发现 $dep_u$ 和 $dep_v$ 都可以直接查($dep_v$ 用按点权排序后的前缀和),$dep_{lca}$ 的话,不好查
可以把每个 $v$ 到根的路径上所有边的覆盖次数全 $+1$,然后询问的时候从 $u$ 走到根,在对应版本的线段树上查询这条路径上 (边权 $\times$ 覆盖次数) 之和,得到的就是所有 $v$ 的 $dep_{lca(u,v)}$ 之和
主席树跟上一种做法一样,也是关于点权的前缀:把点按点权从小到大依次插入,第 $i$ 个版本对应点权前 $i$ 小的点,查询的时候用版本 $r$ 减去版本 $l-1$
// Solution 2: persistent segment tree over the heavy-light dfs order.
//
// sum_v dis(u, v) = cnt * dis[u] + sum_v dis[v] - 2 * sum_v dis[lca(u, v)].
// Vertices are inserted in (weight, id) order; inserting v adds +1 coverage to
// every edge on the path v -> 1 (an edge is indexed by the dfs position of its
// lower endpoint). In version i, summing (edge length * coverage) along
// u -> 1 gives sum of dis[lca(u, v)] over the i lightest vertices v.
#include <algorithm>
#include <cctype>
#include <cstdio>

using LL = long long;

// Reads one (optionally negative) decimal integer from stdin.
// `ch` is an int so that EOF is representable and isdigit never sees a
// negative char; hitting EOF before any digit yields 0 instead of looping.
inline int read() {
    int x = 0, f = 1;
    int ch = std::getchar();
    for (; ch != EOF && !std::isdigit(ch); ch = std::getchar())
        if (ch == '-') f = -f;
    for (; std::isdigit(ch); ch = std::getchar()) x = 10 * x + ch - '0';
    return x * f;
}

const int maxn = 200010;
int n, q, A;

// A vertex together with its weight, ordered by (weight, id).
struct Node {
    int yr, id;
    bool operator<(const Node &b) const { return (yr == b.yr) ? (id < b.id) : (yr < b.yr); }
} ns[maxn];

// Persistent segment-tree node with permanent (never pushed down) tags:
// tms counts inserts that covered this node's whole range, val accumulates the
// edge-length sum of inserts that only partly covered it.
// NOTE(review): maxn << 8 nodes is over a gigabyte of static storage on
// typical ABIs; confirm this fits the target memory limit.
struct TrNode {
    int ls, rs, tms;
    LL val;
} t[maxn << 8];

// Linked adjacency lists: two directed edges are stored per tree edge.
int first[maxn], to[maxn << 1], nx[maxn << 1], val[maxn << 1], cnt;

// Appends the directed edge u -> v with length w.
inline void add(int u, int v, int w) {
    to[++cnt] = v;
    nx[cnt] = first[u];
    first[u] = cnt;
    val[cnt] = w;
}

int ToT, root[maxn];  // node counter; root[i] = version holding the i lightest vertices
// sum[pos[x]] starts as the length of x's parent edge and is turned into a
// prefix sum in main; dis[x] = distance from vertex 1; dsum[i] = sum of dis
// over the i lightest vertices.
LL sum[maxn], dis[maxn], dsum[maxn];
int fa[maxn], pos[maxn], dfn;
int dep[maxn], bl[maxn], sz[maxn];

// Fills fa, dep, dis and subtree sizes for the subtree rooted at x.
void dfs1(int x) {
    sz[x] = 1;
    for (int i = first[x]; i; i = nx[i]) {
        if (to[i] == fa[x]) continue;
        fa[to[i]] = x;
        dep[to[i]] = dep[x] + 1;
        dis[to[i]] = dis[x] + val[i];
        dfs1(to[i]);
        sz[x] += sz[to[i]];
    }
}

// Heavy-first dfs: assigns chain tops (bl), dfs positions (pos) and records
// the parent-edge length at each position.
void dfs2(int x, int col) {
    int k = 0;
    bl[x] = col;
    pos[x] = ++dfn;
    sum[dfn] = dis[x] - dis[fa[x]];
    for (int i = first[x]; i; i = nx[i])
        if (dep[to[i]] > dep[x] && sz[to[i]] > sz[k]) k = to[i];
    if (!k) return;
    dfs2(k, col);
    for (int i = first[x]; i; i = nx[i])
        if (dep[to[i]] > dep[x] && to[i] != k) dfs2(to[i], to[i]);
}

// Builds an all-zero tree over [l, r] rooted at x.
void build(int &x, int l, int r) {
    x = ++ToT;
    if (l == r) return;
    int mid = (l + r) >> 1;
    build(t[x].ls, l, mid);
    build(t[x].rs, mid + 1, r);
}

// Path-copies the version rooted at x, adding +1 coverage on positions [L, R].
void Insert(int &x, int l, int r, int L, int R) {
    t[++ToT] = t[x];
    x = ToT;
    if (L <= l && r <= R) {
        t[x].tms++;
        return;
    }
    t[x].val += sum[std::min(R, r)] - sum[std::max(l - 1, L - 1)];
    int mid = (l + r) >> 1;
    if (L <= mid) Insert(t[x].ls, l, mid, L, R);
    if (R > mid) Insert(t[x].rs, mid + 1, r, L, R);
}

// Sum of (edge length * coverage) over positions [L, R] in the version rooted at x.
LL query(int x, int l, int r, int L, int R) {
    // Whole-range tags of this node apply to the overlap of [l, r] and [L, R].
    LL res = 1LL * (sum[std::min(R, r)] - sum[std::max(l - 1, L - 1)]) * t[x].tms;
    if (L <= l && r <= R) return res + t[x].val;
    int mid = (l + r) >> 1;
    if (L <= mid) res += query(t[x].ls, l, mid, L, R);
    if (R > mid) res += query(t[x].rs, mid + 1, r, L, R);
    return res;
}

// Sum of dis[lca(u, v)] over the v lightest vertices (version v).
LL ask(int u, int v) {
    LL res = 0;
    while (bl[u] != 1) {
        res += query(root[v], 1, n, pos[bl[u]], pos[u]);
        u = fa[bl[u]];
    }
    res += query(root[v], 1, n, 1, pos[u]);
    return res;
}

// Adds +1 coverage to every edge on the path u -> 1 inside version v.
void insert_path(int u, int v) {
    while (bl[u] != 1) {
        Insert(root[v], 1, n, pos[bl[u]], pos[u]);
        u = fa[bl[u]];
    }
    Insert(root[v], 1, n, 1, pos[u]);
}

int main() {
    n = read(), q = read(), A = read();
    for (int i = 1; i <= n; i++) ns[i].yr = read(), ns[i].id = i;
    std::sort(ns + 1, ns + n + 1);
    for (int i = 2; i <= n; i++) {
        int u = read(), v = read(), w = read();
        add(u, v, w);
        add(v, u, w);
    }
    dfs1(1);
    dfs2(1, 1);
    for (int i = 1; i <= n; i++) sum[i] += sum[i - 1], dsum[i] = dsum[i - 1] + dis[ns[i].id];
    build(root[0], 1, n);
    for (int i = 1; i <= n; i++) {
        root[i] = root[i - 1];
        insert_path(ns[i].id, i);
    }
    LL lastans = 0;
    while (q--) {
        int u = read(), a = read(), b = read();
        int l = static_cast<int>(std::min((a + lastans) % A, (b + lastans) % A));
        int r = static_cast<int>(std::max((a + lastans) % A, (b + lastans) % A));
        // Map the weight range [l, r] to the index range [l, r] in ns.
        l = static_cast<int>(std::lower_bound(ns + 1, ns + n + 1, Node{l, 0}) - ns);
        r = static_cast<int>(std::upper_bound(ns + 1, ns + n + 1, Node{r, n}) - ns) - 1;
        lastans = 1LL * (r - l + 1) * dis[u] + dsum[r] - dsum[l - 1] -
                  2 * (ask(u, r) - ask(u, l - 1));
        std::printf("%lld\n", lastans);
    }
    return 0;
}