比较套路的题目。
题目大意
给出一棵(n)个点的树,树以(1)为根,有(m)条向上的路径,每个点有一个覆盖次数限制,你要选尽可能多的路径,使每个点被覆盖的次数都不超过限制。
Solution
首先把路径挂在较深的点,然后(dfs),用线段树维护一下这个点被覆盖的次数是否超限,若超限显然把上端最浅的路径删掉,因为这个路径影响的点最多,删掉这条路径相当于给上面的点“减负”了。
于是一个线段树合并,没了。
贪心的正确性比较显然。
Code
#include <vector>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 300007;
inline int read() {
int x = 0, f = 0;
char c = getchar();
for (; c < '0' || c > '9'; c = getchar()) if (c == '-') f = 1;
for (; c >= '0' && c <= '9'; c = getchar()) x = (x << 1) + (x << 3) + (c ^ '0');
return f ? -x : x;
}
int n, m, c[N];
vector<int> lis[N];
int tot, st[N], nx[N << 1], to[N << 1], dep[N];
void add(int u, int v) {
to[++tot] = v, nx[tot] = st[u], st[u] = tot;
}
void dfs(int u) {
for (int i = st[u]; i; i = nx[i]) if (!dep[to[i]]) dep[to[i]] = dep[u] + 1, dfs(to[i]);
}
int cnt, root[N], sum[N * 50], lson[N * 50], rson[N * 50];
void ins(int& rt, int l, int r, int po) {
if (!rt) rt = ++cnt;
++sum[rt];
if (l == r) return;
int mid = l + r >> 1;
if (po <= mid) ins(lson[rt], l, mid, po);
else ins(rson[rt], mid + 1, r, po);
}
int qry(int rt, int l, int r, int ql, int qr) {
if (!rt) return 0;
if (ql <= l && r <= qr) return sum[rt];
int mid = l + r >> 1, ret = 0;
if (ql <= mid) ret += qry(lson[rt], l, mid, ql, qr);
if (mid + 1 <= qr) ret += qry(rson[rt], mid + 1, r, ql, qr);
return ret;
}
void del(int rt, int l, int r) {
if (!rt) return;
--sum[rt];
if (l == r) return;
int mid = l + r >> 1;
if (sum[lson[rt]] > 0) del(lson[rt], l, mid);
else del(rson[rt], mid + 1, r);
}
void merge(int x, int y) {
if (!x || !y) return;
sum[x] += sum[y];
if (lson[x] && lson[y]) merge(lson[x], lson[y]);
else if (lson[y]) lson[x] = lson[y];
if (rson[x] && rson[y]) merge(rson[x], rson[y]);
else if (rson[y]) rson[x] = rson[y];
}
void solve(int u, int from) {
int sz = lis[u].size();
for (int i = 0; i < sz; ++i) ins(root[u], 1, n, lis[u][i]);
for (int i = st[u]; i; i = nx[i]) if (to[i] != from) {
solve(to[i], u), merge(root[u], root[to[i]]);
}
while (qry(root[u], 1, n, 1, dep[u]) > c[u]) del(root[u], 1, n);
}
int main() {
freopen("fake.in", "r", stdin);
//freopen("fake.out", "w", stdout);
n = read(), m = read();
for (int i = 1; i <= n; ++i) c[i] = read(), root[i] = ++cnt;
for (int i = 1, u, v; i < n; ++i) u = read(), v = read(), add(u, v), add(v, u);
dep[1] = 1, dfs(1);
for (int i = 1, u, v; i <= m; ++i) {
u = read(), v = read();
if (dep[u] < dep[v]) swap(u, v);
lis[u].push_back(dep[v]);
}
solve(1, 0);
printf("%d
", sum[root[1]]);
return 0;
}