题意
给定一个(1)为根的树,每个点有(c,w)两个属性,你需要从某个点(u)子树里选择(k)个点,满足选出来的点(sum_{i=1}^k w(i)leq m),最大化(k imes c(u))
题解
可以启发式合并(splay)来做,( ext{dfs})每个点,每次和儿子的(splay)合并,就得到了一个维护这个点子树的平衡树。再用这个点的(w)(题目中的领导力)乘以子树中最多选多少结点满足(c)(薪水)和(leq m)
肯定贪心选,小的先选
然后这玩意我一开始脑抽写了二分(+kth)的两个(log)做法,会( ext{TLE})一个点,(97pts):
int kth(int u, int k) {
while(1) {
if(sz[ch[u][0]] >= k) u = ch[u][0];
else {
k -= sz[ch[u][0]] + 1;
if(k <= 0) break ;
u = ch[u][1];
}
}
splay(u);
return u;
}
ll qsum(int &u, int k) {
u = kth(u, k);
return sum[ch[u][0]] + w[u];
}
ll query(int &u) {
int l = 1, r = sz[u], ans = 0;
while(l <= r) {
int mid = (l + r) >> 1;
if(qsum(u, mid) <= m) l = (ans = mid) + 1;
else r = mid - 1;
}
return ans;
}
后来发现直接递归找就行了,十分简单,从左子树往自己往右子树依次判能不能选进去
最后的代码就是这样了qwq:
#include <algorithm>
#include <cstdio>
#include <vector>
using namespace std;
typedef long long ll;
const int N = 2e5 + 10;
int n, m, w[N], c[N];
int fa[N], sz[N], ch[N][2];
vector<int> G[N];
ll ans, sum[N];
int dir(int u) { return ch[fa[u]][1] == u; }
void upd(int u) {
sum[u] = sum[ch[u][0]] + sum[ch[u][1]] + w[u];
sz[u] = sz[ch[u][0]] + sz[ch[u][1]] + 1;
}
void rotate(int u) {
int d = dir(u), f = fa[u];
if(fa[u] = fa[f]) ch[fa[u]][dir(f)] = u;
if(ch[f][d] = ch[u][d ^ 1]) fa[ch[f][d]] = f;
fa[ch[u][d ^ 1] = f] = u;
upd(f); upd(u);
}
void ins(int &rt, int u, int f = 0) {
if(!rt) {
rt = u; fa[u] = f;
sz[u] = 1; sum[u] = w[u];
ch[u][0] = ch[u][1] = 0;
return ;
}
ins(ch[rt][w[u] > w[rt]], u, rt);
upd(rt);
}
void splay(int u) {
for(; fa[u]; rotate(u)) if(fa[fa[u]])
rotate(dir(u) == dir(fa[u]) ? fa[u] : u);
}
ll query(int u, ll s) {
if(!u || !s) return 0;
if(sum[u] <= s) return sz[u];
if(sum[ch[u][0]] >= s) return query(ch[u][0], s);
if(sum[ch[u][0]] + w[u] == s) return sz[ch[u][0]] + 1;
if(sum[ch[u][0]] + w[u] > s) return sz[ch[u][0]];
return sz[ch[u][0]] + 1 + query(ch[u][1], s - sum[ch[u][0]] - w[u]);
}
int a[N], l;
void get(int u) {
if(u) {
get(ch[u][0]);
a[++ l] = u;
get(ch[u][1]);
}
}
int merge(int u, int v) {
if(sz[u] < sz[v]) swap(u, v);
l = 0; get(v); int rt = u;
for(int i = 1; i <= l; i ++) {
ins(rt, a[i]);
splay(rt = a[i]);
}
return rt;
}
int dfs(int u) {
int rt = u; sz[u] = 1; sum[u] = w[u];
fa[u] = ch[u][0] = ch[u][1] = 0;
for(int i = 0; i < G[u].size(); i ++)
rt = merge(rt, dfs(G[u][i]));
ans = max(ans, c[u] * query(rt));
return rt;
}
int main() {
scanf("%d%d", &n, &m);
for(int f, i = 1; i <= n; i ++) {
scanf("%d%d%d", &f, &w[i], &c[i]);
G[f].push_back(i);
}
dfs(1);
printf("%lld
", ans);
return 0;
}