题目大意
给你一棵以1为根的树,询问树上两点之间的路径点权在[a,b]之间的点权和。
(假的)解题思路1 树剖+主席树
看到题目就想到树剖+主席树,树剖之后按照dfs序建主席树,主席树用来维护权值信息,树剖用来查询链上的权值和(常数爆炸,卡时间过的)。
const int maxn = 1e5+10;
vector<int> e[maxn];
int dep[maxn], fa[maxn], sz[maxn], son[maxn];
void dfs1(int u, int p) {
sz[u] = 1; son[u] = 0;
for (auto v : e[u]) {
if (v==p) continue;
dep[v] = dep[u]+1;
fa[v] = u;
dfs1(v, u);
sz[u] += sz[v];
if (sz[v]>sz[son[u]]) son[u] = v;
}
}
int top[maxn], tim, id[maxn], rev[maxn];
void dfs2(int u, int t) {
top[u] = t;
id[u] = ++tim;
rev[tim] = u;
if (!son[u]) return;
dfs2(son[u], t);
for (auto v : e[u]) {
if (v!=fa[u] && v!=son[u]) dfs2(v, v);
}
}
struct Node {
int l, r; ll sum;
} hjt[maxn*150];
int tot, n, m, val[maxn], rt[maxn];
void insert(int pre, int &now, int l, int r, int val) {
now = ++tot;
hjt[tot]= {0, 0, 0};
hjt[now] = hjt[pre];
hjt[now].sum += val;
if (l==r) return;
int mid = (l+r)>>1;
if (val<=mid) insert(hjt[pre].l, hjt[now].l, l, mid, val);
else insert(hjt[pre].r, hjt[now].r, mid+1, r, val);
}
ll query(int pre, int now, int l, int r, int L, int R) {
if (l>=L && r<=R) return hjt[now].sum-hjt[pre].sum;
int mid = (l+r)>>1; ll res = 0;
if (L<=mid) res += query(hjt[pre].l, hjt[now].l, l, mid, L, R);
if (R>mid) res += query(hjt[pre].r, hjt[now].r, mid+1, r, L, R);
return res;
}
ll ask(int u, int v, int a, int b) {
ll sum = 0;
while(top[u]!=top[v]) {
if (dep[top[u]]<dep[top[v]]) swap(u, v);
sum += query(rt[id[top[u]]-1], rt[id[u]], 1, 1e9+10, a, b);
u = fa[top[u]];
}
if (dep[u]>dep[v]) swap(u, v);
sum += query(rt[id[u]-1], rt[id[v]], 1, 1e9+10, a, b);
return sum;
}
void init() {
tim = tot = 0;
for (int i = 0; i<=n; ++i) {
e[i].clear();
id[i] = rev[i] = dep[i] = son[i] = rt[i] = top[i] = fa[i] = 0;
}
}
int main() {
while(~scanf("%d%d", &n, &m)) {
init();
for (int i = 1; i<=n; ++i) scanf("%d", &val[i]);
for (int i = 1, a, b; i<n; ++i) {
scanf("%d%d", &a, &b);
e[a].push_back(b);
e[b].push_back(a);
}
dfs1(1, -1); dfs2(1, 1);
for (int i = 1; i<=n; ++i) insert(rt[i-1], rt[i], 1, 1e9+10, val[rev[i]]);
for (int i = 1; i<=m; ++i) {
int s, t, a, b; scanf("%d%d%d%d", &s, &t, &a, &b);
printf("%lld%c", ask(s, t, a, b), i==m ? '
':' ');
}
}
return 0;
}
解题思路2 动态开点线段树
因为只涉及询问没有修改,所以可以从离线的角度考虑。我们对这棵树dfs的时候,第一次访问当前的点把它的权值加进权值线段树里,回溯的时候删掉它的权值,那么我们在第一次访问某个点的时候线段树存的始终是1到当前点的权值信息,这样的话就能想到(ans = val[1, u]+val[1, v]-2 imes val[1,lca] + val[lca] imes chose[lca]),也就是可以把答案拆成1到u的[a,b]之间的权值和,1到v的[a,b]之间的权值和,以及1到lca的[a,b]之间的权值和,注意本来lca的值是被减两次的,但如果其值在范围外的话,就不用再加val[lca]了。
const int maxn = 1e5+10;
vector<int> e[maxn];
int f[maxn][21], tot;
struct INFO {
int l, r; ll sum;
} tr[maxn*100];
struct Q {
int l, r, id;
};
vector<Q> q[maxn], qq[maxn];
ll ans[maxn];
void insert(int rt, int l, int r, int pos, int val) {
if (l==r) {
tr[rt].sum += val;
return;
}
int mid = (l+r)>>1;
if (pos<=mid) {
if (!tr[rt].l) {
tr[rt].l = ++tot;
tr[tot] = {0, 0, 0};
}
insert(tr[rt].l, l, mid, pos, val);
}
else {
if (!tr[rt].r) {
tr[rt].r = ++tot;
tr[tot] = {0, 0, 0};
}
insert(tr[rt].r, mid+1, r, pos, val);
}
tr[rt].sum = tr[tr[rt].l].sum+tr[tr[rt].r].sum;
}
ll query(int rt, int l, int r, int L, int R) {
if (l>=L && r<=R) return tr[rt].sum;
int mid = (l+r)>>1; ll sum = 0;
if (L<=mid) sum += query(tr[rt].l, l, mid, L, R);
if (R>mid) sum += query(tr[rt].r, mid+1, r, L, R);
return sum;
}
int n, m, val[maxn], dep[maxn];
void dfs1(int u, int p) {
for (auto v : e[u]) {
if (v==p) continue;
dep[v] = dep[u]+1;
f[v][0] = u;
for (int i = 1; i<=20; ++i) f[v][i] = f[f[v][i-1]][i-1];
dfs1(v, u);
}
}
void dfs2(int u, int p) {
insert(1, 1, 1e9+10, val[u], val[u]);
for (auto t : q[u]) ans[t.id] += query(1, 1, 1e9+10, t.l, t.r);
for (auto t : qq[u]) {
ans[t.id] -= 2*query(1, 1, 1e9+10, t.l, t.r);
if (val[u]>=t.l && val[u]<=t.r) ans[t.id] += val[u];
}
for (auto v : e[u]) {
if (v==p) continue;
dfs2(v, u);
}
insert(1, 1, 1e9+10, val[u], -val[u]);
}
int lca(int u, int v) {
if (dep[u]<dep[v]) swap(u, v);
for (int i = 20; i>=0; --i)
if (dep[f[u][i]]>=dep[v]) u = f[u][i];
if (u==v) return u;
for (int i = 20; i>=0; --i)
if (f[u][i] != f[v][i]) u = f[u][i], v = f[v][i];
return f[u][0];
}
void init() {
clr(f, 0); tot = 0;
for (int i = 0; i<=n; ++i) e[i].clear();
for (int i = 0; i<=m; ++i) q[i].clear(), qq[i].clear(), ans[i] = 0;
}
int main() {
while(~scanf("%d%d", &n, &m)) {
init();
for (int i = 1; i<=n; ++i) scanf("%d", &val[i]);
for (int i = 1, a, b; i<n; ++i) {
scanf("%d%d", &a, &b);
e[a].push_back(b);
e[b].push_back(a);
}
tr[++tot] = {0, 0, 0};
dep[1] = 1;
dfs1(1, -1);
for (int i = 1; i<=m; ++i) {
int s, t, a, b; scanf("%d%d%d%d", &s, &t, &a, &b);
q[s].push_back({a, b, i});
q[t].push_back({a, b, i});
qq[lca(s, t)].push_back({a, b, i});
//cout << s << ' ' << t << ' ' << lca(s, t) << endl;
}
dfs2(1, -1);
for (int i = 1; i<=m; ++i) printf(i==m ? "%lld
":"%lld ", ans[i]);
}
return 0;
}
解题思路3 树剖+树状数组
虽然比上面多了一个log但是却快了很多。。。询问可以拆成[a,b] = [1,b]-[1,a-1]的形式,这样把插入的权值按从小到大排序之后就只需求一个树上两点的权值和即可。具体做法是分别将点权和询问的左右端点按从小到大排序,树剖之后再把权值小于等于询问端点的权值对应的点的dfs序加进树状数组里,就可以用树剖查询两点之间的权值和了。
const int maxn = 2e5+10;
int n, m;
vector<int> e[maxn];
P a[maxn];
struct INFO {
int u, v, id, x, y;
} q[maxn];
ll c[maxn], ans[maxn];
void add(int x, int y) {
while(x<maxn) {
c[x] += y;
x += x&-x;
}
}
ll ask(int x) {
ll sum = 0;
while(x) {
sum += c[x];
x -= x&-x;
}
return sum;
}
int dep[maxn], fa[maxn], sz[maxn], son[maxn];
void dfs1(int u, int p) {
sz[u] = 1;
for (auto v : e[u]) {
if (v==p) continue;
dep[v] = dep[u]+1;
fa[v] = u;
dfs1(v, u);
sz[u] += sz[v];
if (sz[v]>sz[son[u]]) son[u] = v;
}
}
int top[maxn], tim, id[maxn], rev[maxn];
void dfs2(int u, int t) {
top[u] = t;
id[u] = ++tim;
rev[tim] = u;
if (!son[u]) return;
dfs2(son[u], t);
for (auto v : e[u]) {
if (v!=fa[u] && v!=son[u]) dfs2(v, v);
}
}
ll query(int u, int v) {
ll sum = 0;
while(top[u]!=top[v]) {
if (dep[top[u]]>=dep[top[v]]) {
sum += ask(id[u])-ask(id[top[u]]-1);
u = fa[top[u]];
}
else {
sum += ask(id[v])-ask(id[top[v]]-1);
v = fa[top[v]];
}
}
if (dep[u]<=dep[v]) sum += ask(id[v])-ask(id[u]-1);
else sum += ask(id[u])-ask(id[v]-1);
return sum;
}
void init() {
clr(c, 0); clr(ans, 0); clr(son, 0); tim = 0;
for (int i = 0; i<=n; ++i) e[i].clear();
}
int main() {
while(~scanf("%d%d", &n, &m)) {
init();
for (int i = 1; i<=n; ++i) scanf("%d", &a[i].x), a[i].y = i;
sort(a+1, a+n+1);
for (int i = 1, a, b; i<n; ++i) {
scanf("%d%d", &a, &b);
e[a].push_back(b);
e[b].push_back(a);
}
dep[1] = 1;
dfs1(1, -1); dfs2(1, 1);
int tot = 0;
for (int i = 1, s, t, a, b; i<=m; ++i) {
scanf("%d%d%d%d", &s, &t, &a, &b);
q[++tot] = {s, t, i, a-1, -1};
q[++tot] = {s, t, i, b, 1};
}
sort(q+1, q+tot+1, [](INFO a, INFO b) {return a.x<b.x;});
int p = 1;
for (int i = 1; i<=tot; ++i) {
while(p<=n && a[p].x<=q[i].x) add(id[a[p].y], a[p].x), p++;
ans[q[i].id] += 1LL*q[i].y*query(q[i].u, q[i].v);
//cout << q[i].id << ' ' << ans[q[i].id] << ' ' << q[i].x << ' ' << ask(1024) << endl;
}
for (int i = 1; i<=m; ++i) printf(i==m ? "%lld
":"%lld ", ans[i]);
}
return 0;
}