一、题目
二、解法
设 (dp[u][0/1]) 表示解决 (u) 子树内所有问题,(u) 的父边选不选的方案数,转移的时候把 (dp[v][1]+w-dp[v][0]) 从小到大排序,然后取一个前缀让 (u) 满足限制即可。
难点就是要对所有 (x) 求出答案,首先发现 (d[u]leq x) 的点 (u) 是没有用的,因为它一定合法。
那么我们从小到大枚举 (x),每次找出这样的点 (u),因为它已经合法所以可以把它直接删除,但它连的边可能还有用,我们直接把这条边的边权塞进 (v) 的大根堆中,这个堆就维护最优的删边方案,如果大小足够就把堆顶弹出即可。
然后原图就分成了若干个连通块,我们对于每个连通块分别 (dp) 即可,考虑 (v) 的转移时也把选择边 ((u,v)) 的代价塞进堆中,然后用堆决策即可。注意 (dp) 完了之后还要还原堆,因为要支持删除所以我手写了,挺好玩的。
(dp) 的时候注意只能访问度数大于等于 (x) 的点,所以我们要把每个点的边按终点的度数大小排序。
因为每个点只会被 (dp) 度数次,所以总时间复杂度是度数的累加,时间复杂度 (O(nlog n))
三、总结
把时间复杂度和某些量扯上关系,考虑和原问题紧密相关的量。
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
#define int long long
const int M = 250005;
int read()
{
int x=0,f=1;char c;
while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int n,sum,t,d[M],id[M],vis[M],dp[M][2];
struct node
{
int v,c;
bool operator < (const node &b) const
{
return d[v]>d[b.v];
}
};vector<node> g[M];
struct heap
{
vector<int> v;
heap() {v.clear();}
void push(int x) {v.push_back(x);push_heap(v.begin(),v.end());}
void pop() {pop_heap(v.begin(),v.end());v.pop_back();}
int top() {return v[0];}
int sz() {return v.size();}
};
struct zxy
{
heap a,b;int sum,s;
zxy() {sum=s=0;}
void push(int x) {a.push(x);sum+=x;s++;}
void del(int x) {b.push(x);sum-=x;s--;}
void work() {while(a.sz() && b.sz() && a.top()==b.top()) a.pop(),b.pop();}
void pop() {work();s--;sum-=a.top();a.pop();}
int top() {work();return a.top();}
int sz() {return s;}
}h[M];
bool cmp(int x,int y)
{
return d[x]<d[y];
}
void era(int u)
{
for(node x:g[u])
{
if(d[x.v]<=t) break;
h[x.v].push(x.c);
}
}
void dfs(int u)
{
dp[u][0]=dp[u][1]=0;vis[u]=t;
int nd=d[u]-t,tmp=0;vector<int> v1,v2;
while(h[u].sz()>nd) h[u].pop();
for(node x:g[u])
{
if(d[x.v]<=t) break;
if(vis[x.v]==t) continue;
dfs(x.v);
int c=dp[x.v][1]+x.c-dp[x.v][0];
if(c<=0) {nd--;tmp+=dp[x.v][1]+x.c;continue;}
h[u].push(c);tmp+=dp[x.v][0];v1.push_back(c);
}
for(;h[u].sz() && h[u].sz()>nd;h[u].pop()) v2.push_back(h[u].top());
dp[u][0]=tmp+h[u].sum;
for(;h[u].sz() && h[u].sz()>nd-1;h[u].pop()) v2.push_back(h[u].top());
dp[u][1]=tmp+h[u].sum;
while(v2.size()) h[u].push(v2.back()),v2.pop_back();
while(v1.size()) h[u].del(v1.back()),v1.pop_back();
}
signed main()
{
n=read();
for(int i=1;i<n;i++)
{
int u=read(),v=read(),w=read();
g[u].push_back(node{v,w});
g[v].push_back(node{u,w});
d[u]++;d[v]++;sum+=w;
}
for(int i=1;i<=n;i++)
id[i]=i,sort(g[i].begin(),g[i].end());
sort(id+1,id+1+n,cmp);
printf("%lld",sum);int i=1;
for(t=1;t<n;t++)
{
while(i<=n && d[id[i]]<=t) era(id[i++]);
int ans=0;
for(int j=i;j<=n;j++)
{
if(vis[id[j]]==t) continue;
dfs(id[j]);
ans+=dp[id[j]][0];
}
printf(" %lld",ans);
}
}