给定一棵大小为 (n) 的树,第 (i) 个点的点权为 (2^i) ,删掉 (k) 个点及其连边,使得剩下的点组成一个连通块,且权值和最大,输出要删掉的点
(n, kleq10^6)
贪心,倍增,dfs序
很容易想到一个贪心:不断删掉能被删除且权值最小的点,用堆实现
但很明显这是错的:局部最优 ( eq) 全局最优
因为点权为 (2^i) ,所以与其选择 (1, 2, cdots, i-1) 不如选 (i) 这一个点
考虑这样一个 (O(n^2)) 贪心:从大往小确定不被删的点。以 (n) 为根,若当前枚举到的节点可选则将根到这个点的路径上的所有点都选掉
可以发现,问题转化为了快速求出一个点是否可选,并标记路径
因为最多会标记 (n-1) 条路径,所以可以暴力标记,倍增查询
时间复杂度 (O(nlog n)) ,空间复杂度 (O(nlog n))
原操作还有另一种实现方式:树状数组维护dfs序上每个点的距离,标记路径时修改子树内答案
这种做法虽然略显繁琐,但是不失为一种常用的高效算法
时间复杂度 (O(nlog n)) ,空间复杂度 (O(n))
倍增代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e6 + 10;
int n, k, fa[21][maxn]; bool vis[maxn];
vector <int> e[maxn];
void dfs(int u, int f) {
fa[0][u] = f;
for (int i = 1; i < 21; i++) {
fa[i][u] = fa[i - 1][fa[i - 1][u]];
}
for (int v : e[u]) {
if (v != f) dfs(v, u);
}
}
int lca(int u) {
int res = 0;
for (int i = 20; ~i; i--) {
if (!vis[fa[i][u]]) {
u = fa[i][u], res |= 1 << i;
}
}
return res;
}
int main() {
scanf("%d %d", &n, &k);
for (int i = 1, u, v; i < n; i++) {
scanf("%d %d", &u, &v);
e[u].push_back(v), e[v].push_back(u);
}
dfs(n, 0), vis[0] = 1;
int tmp = n - k;
for (int i = n; i; i--) {
int t = lca(i);
if (t < tmp) {
int cur = i;
while (!vis[cur]) {
vis[cur] = 1, tmp--, cur = fa[0][cur];
}
}
}
for (int i = 1; i < n; i++) {
if (!vis[i]) printf("%d ", i);
}
return 0;
}
dfs序+树状数组代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e6 + 10;
int n, k, now, a[maxn], c[maxn], sz[maxn], fa[maxn], tid[maxn], dep[maxn]; bool vis[maxn];
vector <int> e[maxn];
void upd(int pos, int x) {
for (; pos <= n; pos += pos & -pos) {
c[pos] += x;
}
}
int query(int pos) {
int res = 0;
for (; pos; pos &= pos - 1) {
res += c[pos];
}
return res;
}
int dfs(int u, int f) {
a[++now] = u, tid[u] = now, fa[u] = f, dep[u] = dep[f] + 1;
for (int v : e[u]) {
if (v != f) sz[u] += dfs(v, u);
}
return ++sz[u];
}
int main() {
scanf("%d %d", &n, &k);
for (int i = 1, u, v; i < n; i++) {
scanf("%d %d", &u, &v);
e[u].push_back(v), e[v].push_back(u);
}
dfs(n, 0), upd(1, -1);
for (int i = 1; i <= n; i++) {
upd(i, dep[a[i]] - dep[a[i - 1]]);
}
int tmp = n - k; vis[0] = 1;
for (int i = n; i; i--) {
if (query(tid[i]) < tmp) {
for (int cur = i; !vis[cur]; tmp--, vis[cur] = 1, cur = fa[cur]) {
upd(tid[cur], -1), upd(tid[cur] + sz[cur], 1);
}
}
}
for (int i = 1; i < n; i++) {
if (!vis[i]) printf("%d ", i);
}
return 0;
}