题意
给一颗根为(1)的有根树,树上每个点的权值为(w_i),大小为(a_i)
有(q)个询问,给出两个参数(x,s)
询问在以(x)为根的子树中,选出若干个点,这些点的大小之和不超过(s),并最大化权值之和
解法
一个明显的(O(NS^2))的树形背包暴力
设(f[x][k])为以(x)为根的子树中大小和小于(k)的结点的最大权值和,转移也很显然
我们可以发现对于一个点(x),它的状态是由所有子节点的状态合并转移过来的
于是可以考虑启发式合并的小(trick)
把整棵树进行轻重链剖分,对于一个节点,先把它重儿子的状态复制上去,对于轻儿子暴力转移:具体来说,就是把轻儿子为根的子树中的每个节点视作一个物品,背包即可
有两种证明复杂度是(O(NSlogN))的方法:
第一种:
考虑启发式合并时,每次合并以后子树大小至少变为原来的两倍,所以每个节点最多被合并(logN)次。合并单个节点的复杂度是(O(S))的,所以总复杂度为(O (NSlogN))
第二种:
考虑轻重链剖分的性质,对于任意一个节点,其到根节点的路径上有(logN)级别个轻路径,因此每个节点最多被合并(logN)次。所以复杂度得证
记住对于这类题目的模型:
树上的问题,状态由子节点状态合并转移而来,可以考虑轻重链剖分启发式合并,重儿子状态直接继承,轻儿子状态暴力转移
代码
#include <cstdio>
#include <cctype>
#include <cstring>
using namespace std;
int read();
const int N = 4e5 + 10;
int n, q;
int w[N], a[N];
int cap;
int head[N], to[N << 1], nxt[N << 1];
int sz[N], son[N];
long long f[5010][5010];
inline void add(int x, int y) {
to[++cap] = y, nxt[cap] = head[x], head[x] = cap;
to[++cap] = x, nxt[cap] = head[y], head[y] = cap;
}
inline long long max(long long x, long long y) {
return x > y ? x : y;
}
void DFS3(int x, int fa, long long *f) {
for (int i = 5000; i >= a[x]; --i) f[i] = max(f[i], f[i - a[x]] + w[x]);
for (int i = head[x]; i; i = nxt[i]) {
if (to[i] == fa) continue;
DFS3(to[i], x, f);
}
}
void DFS2(int x, int fa) {
for (int i = head[x]; i; i = nxt[i]) {
if (to[i] == fa) continue;
DFS2(to[i], x);
}
memcpy(f[x], f[son[x]], sizeof f[son[x]]);
for (int i = head[x]; i; i = nxt[i]) {
if (to[i] == fa || to[i] == son[x]) continue;
DFS3(to[i], x, f[x]);
}
for (int i = 5000; i >= a[x]; --i) f[x][i] = max(f[x][i], f[x][i - a[x]] + w[x]);
}
void DFS(int x, int fa) {
sz[x]++;
for (int i = head[x]; i; i = nxt[i]) {
if (to[i] == fa) continue;
DFS(to[i], x);
sz[x] += sz[to[i]];
if (sz[to[i]] > sz[son[x]]) son[x] = to[i];
}
}
int main() {
n = read();
int u, v;
for (int i = 1; i < n; ++i) {
u = read(), v = read();
add(u, v);
}
for (int i = 1; i <= n; ++i)
scanf("%d%d", w + i, a + i);
DFS(1, 0);
DFS2(1, 0);
q = read();
while (q--) {
u = read(), v = read();
printf("%lld
", f[u][v]);
}
return 0;
}
int read() {
int x = 0, c = getchar();
while (!isdigit(c)) c = getchar();
while (isdigit(c)) x = x * 10 + c - 48, c = getchar();
return x;
}