Description
给定一棵树,树上每一个点都有一个点权。你要在这棵树上选择一个点集,需要满足树上任意一条边所连的两个端点中至少有一个端点被选择,现在你需要最小化这个点集的点权和
有多次询问,每次询问钦定两个点分别选或不选,整棵树的代价是多少
Solution
我们需要维护以下三个数组
(f[i][0/1])表示以(i)为根的子树中所有节点,(i)号点不选(/)选,所花费的最小代价是多少
(F[i][j][0/1][0/1])表示(i)号点向上跳(2^j)的父亲(f)的子树(去掉以(i)为根的子树),(i)号点不选(/)选,(f)号点不选(/)选,所花费的最小代价
(rf[i][0/1])表示以(i)号点不选(/)选,整棵树的最小代价
然后询问的时候分两种情况讨论
一种情况就是一个点为另一个点的祖先,就把较深的节点(a)倍增到另一个点(b)的儿子处,然后讨论一下(b)选(/)不选
否则,就先把两个点都倍增到它们(lca)的儿子节点处,然后讨论(lca)选(/)不选就可以了
Code
#include <bits/stdc++.h>
using namespace std;
#define fst first
#define snd second
#define squ(x) ((LL)(x) * (x))
#define debug(...) fprintf(stderr, __VA_ARGS__)
typedef long long LL;
typedef pair<int, int> pii;
inline int read() {
int sum = 0, fg = 1; char c = getchar();
for (; !isdigit(c); c = getchar()) if (c == '-') fg = -1;
for (; isdigit(c); c = getchar()) sum = (sum << 3) + (sum << 1) + (c ^ 0x30);
return fg * sum;
}
const int maxn = 1e5 + 10;
const LL inf = 0x3f3f3f3f3f3f3f3f;
vector<int> g[maxn];
struct node {
LL v[2][2];
node() { memset(v, 0x3f, sizeof v); }
}F[maxn][17];
int n, m, d[maxn], fa[maxn][17], w[maxn];
LL f[maxn][2], rf[maxn][2];
void dfs(int now, int _f) {
d[now] = d[_f] + 1, fa[now][0] = _f;
for (int i = 1; i <= 16; i++) fa[now][i] = fa[fa[now][i - 1]][i - 1];
f[now][0] = 0, f[now][1] = w[now];
for (int i = 0; i < g[now].size(); i++) {
int son = g[now][i];
if (son == _f) continue;
dfs(son, now);
f[now][0] += f[son][1];
f[now][1] += min(f[son][0], f[son][1]);
}
}
node merge(const node &a, const node &b) {
node res;
for (int i = 0; i < 2; i++)
for (int j = 0; j < 2; j++)
for (int k = 0; k < 2; k++)
res.v[i][j] = min(res.v[i][j], a.v[i][k] + b.v[k][j]);
return res;
}
void Dfs(int now, int f0, int f1) {
rf[now][0] = f[now][0] + f1;
rf[now][1] = f[now][1] + min(f0, f1);
for (int i = 1; i <= 16; i++) F[now][i] = merge(F[now][i - 1], F[fa[now][i - 1]][i - 1]);
for (int i = 0; i < g[now].size(); i++) {
int son = g[now][i];
if (son == fa[now][0]) continue;
F[son][0].v[0][0] = inf;
F[son][0].v[1][0] = f[now][0] - f[son][1];
F[son][0].v[0][1] = F[son][0].v[1][1] = f[now][1] - min(f[son][0], f[son][1]);
Dfs(son, rf[now][0] - f[son][1], rf[now][1] - min(f[son][0], f[son][1]));
}
}
int main() {
#ifdef xunzhen
freopen("defense.in", "r", stdin);
freopen("defense.out", "w", stdout);
#endif
static char Tmp[10];
n = read(), m = read(), scanf("%s", Tmp);
for (int i = 1; i <= n; i++) w[i] = read();
for (int i = 1; i < n; i++) {
int x = read(), y = read();
g[x].push_back(y);
g[y].push_back(x);
}
dfs(1, 0);
Dfs(1, 0, 0);
while (m--) {
int a = read(), x = read(), b = read(), y = read();
if (!(x | y) && (fa[a][0] == b || fa[b][0] == a)) {
printf("-1
"); continue;
}
if (d[a] < d[b]) swap(a, b), swap(x, y);
node A, B;
A.v[x][x] = f[a][x];
for (int i = 16; ~i; i--)
if (d[fa[a][i]] > d[b])
A = merge(A, F[a][i]), a = fa[a][i];
if (fa[a][0] == b) {
LL ans0 = rf[b][0] - f[a][1], ans1 = rf[b][1] - min(f[a][0], f[a][1]);
ans0 += A.v[x][1], ans1 += min(A.v[x][0], A.v[x][1]);
printf("%lld
", y ? ans1 : ans0);
} else {
if (d[a] > d[b]) A = merge(A, F[a][0]), a = fa[a][0];
B.v[y][y] = f[b][y];
for (int i = 16; ~i; i--)
if (fa[a][i] != fa[b][i]) {
A = merge(A, F[a][i]), B = merge(B, F[b][i]);
a = fa[a][i], b = fa[b][i];
}
int lca = fa[a][0];
LL ans0 = rf[lca][0] - f[a][1] - f[b][1], ans1 = rf[lca][1] - min(f[a][0], f[a][1]) - min(f[b][0], f[b][1]);
ans0 += A.v[x][1] + B.v[y][1], ans1 += min(A.v[x][0], A.v[x][1]) + min(B.v[y][0], B.v[y][1]);
printf("%lld
", min(ans0, ans1));
}
}
return 0;
}