题目
题目链接:https://codeforces.com/problemset/problem/1119/F
双倍经验:https://www.luogu.com.cn/problem/P7600
有一个 (n) 个结点的树,每条边有边权,结点度数就是与之相连的边数量。对于 (0 le x < n),删掉一些边使每个结点的度数不大于 (x),求出删掉的边的权值和最小值。
(nleq 2.5 imes 10^5)。
思路
为 APIO 出题人辛苦地精心准备高质量原创题目行为点赞!!!/tuu
不提这一点,我写 (O(n^2log n)) 暴力,数组都是 (O(n)) 的一直告诉我 MLE,吐了。
设 (f[x][0/1]) 表示点 (x) 为根的子树内,其中 (x) 与父亲的边不删 / 删时,要求每个点度数不超过 (k) 的最小代价。
暴力的话就直接先钦定 (x) 与它所有儿子的边都是不删除的,然后删除其中最少 ( ext{deg}[x]-k) 条边。
那么直接把所有儿子 (y) 的 (f[y][1]+ ext{dis}[x][y]-f[y][0]) 取最小的 ( ext{deg}[x]-k) 个删除,其他都不删除。
时间复杂度 (O(n^2log n))。
观察到对于点 (x),当 (kgeq ext{deg}[x]) 的时候,(x) 就无关度数限制了,可以任意调整 (x) 的所有边,最小化 (x) 旁边的点的代价。
我们对于每一个点维护一个堆(用处与上文暴力的堆一样),当求解到要求所有点度数不超过 (k) 时,把 ( ext{deg}[x]=k) 的所有 (x) 拎出来,对于一条边 ((x,y)),把 ( ext{dis}[x][y]) 扔进 (y) 的堆中。
然后枚举所有度数大于 (k) 的点跑树形 dp,注意不用便利度数不超过 (k) 的点。这个可以用 vector 存图然后给出边排个序。
注意每次跑完一个点后需要还原这个点的堆,也就是堆中只保留其他度数不超过 (k) 的点到这个点的距离。所以维护两个堆实现可删除堆即可。
时间复杂度 (O((sum^{n}_{i=1} deg_i)log n)=O(nlog n))。
细节确实有点多,建议阅读代码理解。
代码
#include <bits/stdc++.h>
#define mp make_pair
#define ST first
#define ND second
using namespace std;
typedef long long ll;
const int N=250010;
const ll Inf=8e18;
int n,deg[N],vis[N],id[N];
ll ans,f[N][2],sum[N];
priority_queue<ll> q[N],p[N];
vector<ll> a,b;
vector<pair<int,ll> > e[N];
bool cmp(int x,int y)
{
return deg[x]<deg[y];
}
bool cmp1(pair<int,ll> x,pair<int,ll> y)
{
return deg[x.ST]>deg[y.ST];
}
void add(int from,int to,ll dis)
{
e[from].push_back(mp(to,dis));
deg[to]++;
}
void Erase(int x,ll v=Inf)
{
if (v!=Inf) p[x].push(v),sum[x]-=v;
while (q[x].size() && p[x].size() && q[x].top()==p[x].top())
q[x].pop(),p[x].pop();
}
void Push(int x,ll v)
{
q[x].push(v); sum[x]+=v; Erase(x);
}
void Pop(int x)
{
sum[x]-=q[x].top(); q[x].pop(); Erase(x);
}
void dfs(int x,int fa,int k)
{
vis[x]=k; f[x][0]=f[x][1]=0;
for (int i=0;i<e[x].size();i++)
{
int v=e[x][i].ST;
if (deg[v]<=k) break;
if (v!=fa) dfs(v,x,k);
}
a.clear(); b.clear();
int cnt=deg[x]-k;
while (q[x].size()-p[x].size()>cnt) Pop(x);
for (int i=0;i<e[x].size();i++)
{
int v=e[x][i].ST; ll dis=e[x][i].ND;
if (deg[v]<=k) break;
if (v!=fa)
{
if (f[v][0]<f[v][1]+dis)
{
f[x][0]+=f[v][0];
Push(x,f[v][1]+dis-f[v][0]);
a.push_back(f[v][1]+dis-f[v][0]);
}
else cnt--,f[x][0]+=f[v][1]+dis;
}
}
f[x][1]=f[x][0];
for (;q[x].size()-p[x].size()>max(cnt,0);Pop(x))
b.push_back(q[x].top());
f[x][0]+=sum[x];
for (;q[x].size()-p[x].size()>max(cnt-1,0);Pop(x))
b.push_back(q[x].top());
f[x][1]+=sum[x];
for (int i=0;i<b.size();i++) Push(x,b[i]);
for (int i=0;i<a.size();i++) Erase(x,a[i]);
}
int main()
{
memset(vis,-1,sizeof(vis));
scanf("%d",&n);
for (int i=1,x,y,z;i<n;i++)
{
scanf("%d%d%d",&x,&y,&z);
add(x,y,z); add(y,x,z);
}
for (int i=1;i<=n;i++)
{
id[i]=i;
sort(e[i].begin(),e[i].end(),cmp1);
}
sort(id+1,id+1+n,cmp);
for (int i=0,j=1;i<n;i++)
{
for (;j<=n && deg[id[j]]<=i;j++)
{
int x=id[j];
for (int k=0;k<e[x].size();k++)
{
int v=e[x][k].ST;
if (deg[v]<=i) break;
Push(v,e[x][k].ND);
}
}
ans=0;
for (int k=j;k<=n;k++)
if (vis[id[k]]!=i)
{
dfs(id[k],0,i);
ans+=f[id[k]][0];
}
cout<<ans<<' ';
}
return 0;
}