https://www.luogu.com.cn/problem/CF1654G
有点厉害的题。
首先 \(h_i\) 可以通过 bfs 求出来。考虑最优方案是什么,一定是从当前的 \(x\) 的跳到最低的 \(y\),然后 \(y\) 的能操作 2。
先理理有什么性质:
-
倘若一个点 \(x\) 的高度是 \(h_x\),那么一定能操作一 \(h_x\) 次,即走这条链。
-
相邻的节点的高度差不超过 1。
-
倘若两个相邻的节点的高度相等,那么一定两个节点的两端一定有 2 条不相交的长度相等的链直到黑色节点。
下文中把相邻高度相等的节点叫做特殊点(即能操作 2 的)。
那么,如何处理出每个节点能到达的高度最小的特殊点呢?考虑暴力,从小到大枚举高度,看看是否有这个高度的特殊点,然后从这些点出发,令 \(a[x]\) 为 \(x\) 点到当前高度的特殊点中的最小所需储备能量。即想要从 \(x\) 点跳到特殊点在跳到 \(x\) 的时候至少要 \(a_x\) 的能量。显然特殊点都是 0,考虑类 bfs 转移。
-
\(h_x=h_y\),说明中间只能操作 2 到达,那么 y 的能量储备要更多。
-
\(h_x<h_y\),说明 \(y\) 可以通过 \((x,y)\) 多出 1 点能量,那么 \(a[y]=a[x]-1\)。
需要注意的是,任何时候都要满足 \(a\ge 0\),才能使得没有错误的转移(在相等时转移会出锅)。
分析到这里,会发现这个暴力应该是对的。
我们只要考虑所有特殊点的种类即可。
考虑构造点被特殊点两端的链覆盖尽可能多的次数,发现好像只能构造到
这个样子,其中最多 \(O(n)\) 个点被多覆盖到了 1 次,其他点都是只有 \(1\) 次覆盖的。也就是说 \(\sum_\limits{x是特殊点} h_x\) 的级别是 \(O(n)\) 的,那么种类是 \(O(\sqrt{n})\) 级别的,考虑只在 \([1,\sqrt{n}]\) 出现。
做完了。
#include <bits/stdc++.h>
//#define int long long
#define pb push_back
using namespace std;
const int N=(int)(2e5+5),inf=(int)(2e9);
queue<int>q;
vector<int>g[N],vec[N];
bool vis[N];
int n,h[N],mi[N],a[N];
signed main() {
cin.tie(0); ios::sync_with_stdio(false);
cin>>n; memset(h,0x3f,sizeof(h));
for(int i=1;i<=n;i++) {
int x; cin>>x; if(x) h[i]=0;
}
for(int i=1;i<n;i++) {
int x,y; cin>>x>>y;
g[x].pb(y); g[y].pb(x);
}
// for(int i=1;i<=n;i++) {
// cout<<h[i]<<endl;
// }
for(int i=1;i<=n;i++) if(!h[i]) q.push(i);
while(!q.empty()) { //纯正 bfs,是吧
int x=q.front(); q.pop();
for(int y:g[x]) {
if(h[y]>h[x]+1) {
h[y]=h[x]+1; q.push(y);
}
}
}
// for(int i=1;i<=n;i++) {
// cout<<h[i]<<endl;
// }
for(int i=0;i<=n;i++) {
bool fl=0;
for(int y:g[i]) {
if(h[y]==h[i]) {
fl=1;
}
}
if(fl) vis[h[i]]=1,vec[h[i]].pb(i);
}
memset(mi,0x3f,sizeof(mi));
for(int i=0;i<=n;i++) {
if(!vis[i]) continue ;
// cout<<i<<endl;
for(int j=1;j<=n;j++) a[j]=inf;
for(int x:vec[i]) a[x]=0,mi[x]=min(mi[x],i),q.push(x);
while(!q.empty()) {
int x=q.front(); q.pop();
for(int y:g[x]) {
if(h[y]==h[x]&&a[y]>a[x]+1) {
a[y]=a[x]+1; q.push(y);
} else if(h[y]>h[x]&&a[y]>a[x]-1) {
a[y]=max(0,a[x]-1); q.push(y);
}
}
}
for(int j=1;j<=n;j++) if(a[j]<=0) mi[j]=min(mi[j],i);
}
for(int i=1;i<=n;i++) {
if(!h[i]) cout<<"0 ";
else if(mi[i]>0x3f3f3f) cout<<h[i]<<' ';
else cout<<2*h[i]-mi[i]<<' ';
}
// cout<<'\n';
// for(int i=1;i<=n;i++) cout<<mi[i]<<' ';
return 0;
}
/*
好难
题解都看不懂,只能自己写了
/yun
明天一模了,what should I do?
*/