一、题目
给定 \(n\) 个点的树,边有边权。每个点有一个种类 \(a_u\in\{0,1\}\),对于 \(a_u=0\),定义 \(ms(u)=\max_{a_v=1} w(u,v)\),其中 \(w(u,v)\) 表示 \((u,v)\) 路径上的最大边减去路径上的最小边。
要求把某个 \(a_u=0\) 变成 \(a_u=1\),最大化 \(\min ms(u)\),输出翻转点和最大值。
\(n\leq 10^5\)
二、解法
首先考虑怎么求出所有 \(ms(u)\),可以直接点分治,设 \(a_u/b_u\) 表示分治中心到点 \(u\) 的最大边\(/\)最小边,所以 \(w(u,v)=\max\{a_u,a_v\}-\min\{b_u,b_v\}\),这可以拆分成 \(\max\{a_u-b_u,a_v-b_v,a_u-b_v,a_v-b_u\}\),只需要要求 \(u,v\) 在不同子树即可。那么我们前缀后缀都扫一遍,维护一下最大值就可以方便地计算了。
考虑二分答案 \(L\),那么我们翻转一个点 \(x\) 之后,需要要求其他点 \(u\) 满足 \(\max(ms(u),w(x,u))\geq L\)
那么现在的问题变成了如何对每个可能的 \(x\) 判断 \(\forall u,w(x,u)\geq L\),还是考虑点分治,优化讨论这两种情形:
- \(a_x\geq a_u\),那么 \(w(x,u)=a_x-\min(b_u,b_x)\),记录最小的 \(b_u\) 即可。
- \(a_x\leq a_u\),那么 \(w(x,u)=a_u-\min(b_u,b_x)\),发现我们只需要考虑满足 \(a_u-b_u< L\) 的点 \(u\),此时判断 \(a_u-b_x\geq L\) 是否成立即可。那么记录最小的 \(a_u\),满足 \(a_x\leq a_u\) 并且 \(a_u-b_u<L\)
所以可以在第一次点分治中把分治子树内的点按 \(a\) 排好序,这样二分时只需要枚举分治中心,然后按顺序扫描两次,根据上面的讨论就可以判断了,时间复杂度 \(O(n\log n\log V)\)
#include <cstdio>
#include <vector>
#include <iostream>
#include <algorithm>
using namespace std;
const int M = 100005;
const int inf = 0x3f3f3f3f;
int read()
{
int x=0,f=1;char c;
while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int n,tot,a[M],b[M],c[M],f[M],vis[M],dep[M];
int rt,sz,siz[M],zxy[M],mx[M][20],mn[M][20],ms[M];
struct edge{int v,c,next;}e[M<<1];
vector<int> dm[M],s;vector<vector<int>> son;
void upd(int &x,int y) {x=max(x,y);}
void find(int u,int fa)
{
siz[u]=1;zxy[u]=0;
for(int i=f[u];i;i=e[i].next)
{
int v=e[i].v,c=e[i].c;
if(v==fa || vis[v]) continue;
find(v,u);
siz[u]+=siz[v];
upd(zxy[u],siz[v]);
}
upd(zxy[u],sz-siz[u]);
if(zxy[rt]>zxy[u]) rt=u;
}
void dfs(int u,int fa,int d)
{
s.push_back(u);
for(int i=f[u];i;i=e[i].next)
{
int v=e[i].v,c=e[i].c;
if(v==fa || vis[v]) continue;
mx[v][d]=max(mx[u][d],c);
mn[v][d]=min(mn[u][d],c);
dfs(v,u,d);
}
}
void solve(int u,int d)
{
vis[u]=1;dep[u]=d;son.clear();
mx[u][d]=-inf;mn[u][d]=inf;
int l=0,A=-inf,B=-inf,C=-inf,D=a[u];
for(int i=f[u];i;i=e[i].next)
{
int v=e[i].v,c=e[i].c;
if(vis[v]) continue;
mx[v][d]=mn[v][d]=c;
s.clear();dfs(v,0,d);
son.push_back(s);l++;
for(int x:s) if(a[x]==0)
{
dm[u].push_back(x);
if(D) upd(ms[x],mx[x][d]-mn[x][d]);
upd(ms[x],A-mn[x][d]);
upd(ms[x],mx[x][d]+B);
upd(ms[x],C);
}
for(int x:s) if(a[x]==1)
{
upd(A,mx[x][d]);
upd(B,-mn[x][d]);
upd(C,mx[x][d]-mn[x][d]);
D=1;
}
}
if(!a[u]) upd(ms[u],C);
A=B=C=-inf;D=a[u];
for(int i=l-1;i>=0;i--)
{
s=son[i];
for(int x:s) if(a[x]==0)
{
if(D) upd(ms[x],mx[x][d]-mn[x][d]);
upd(ms[x],A-mn[x][d]);
upd(ms[x],mx[x][d]+B);
upd(ms[x],C);
}
for(int x:s) if(a[x]==1)
{
upd(A,mx[x][d]);
upd(B,-mn[x][d]);
upd(C,mx[x][d]-mn[x][d]);
D=1;
}
}
dm[u].push_back(u);
sort(dm[u].begin(),dm[u].end(),
[&](int i,int j){return mx[i][d]<mx[j][d];});
for(int i=f[u];i;i=e[i].next)
{
int v=e[i].v;
if(vis[v]) continue;
rt=0;sz=siz[v];find(v,u);
solve(rt,d+1);
}
}
int check(int L)
{
for(int i=1;i<=n;i++)
{
b[i]=a[i]==0;
c[i]=a[i]==0 && ms[i]<L;
}
for(int u=1;u<=n;u++)
{
int d=dep[u],l=dm[u].size();
int pre=-inf,suf=inf;
for(int i=0;i<l;i++)
{
int x=dm[u][i];
if(a[x]==0 && pre>-inf)
b[x]&=mx[x][d]-min(mn[x][d],pre)>=L;
if(c[x] && mx[x][d]-mn[x][d]<L)
upd(pre,mn[x][d]);
}
for(int i=l-1;i>=0;i--)
{
int x=dm[u][i];
if(a[x]==0 && suf<inf)
b[x]&=suf-mn[x][d]>=L;
if(c[x] && mx[x][d]-mn[x][d]<L)
suf=min(suf,mx[x][d]);
}
}
for(int i=1;i<=n;i++) if(b[i]) return i;
return 0;
}
signed main()
{
freopen("c.in","r",stdin);
freopen("c.out","w",stdout);
n=read();
for(int i=1;i<=n;i++) a[i]=read();
for(int i=1;i<n;i++)
{
int u=read(),v=read(),c=read();
e[++tot]=edge{v,c,f[u]},f[u]=tot;
e[++tot]=edge{u,c,f[v]},f[v]=tot;
}
zxy[0]=sz=n;find(1,0);solve(rt,0);
int l=0,r=inf,ans=0;
while(l<=r)
{
int mid=(l+r)>>1;
if(check(mid)) ans=mid,l=mid+1;
else r=mid-1;
}
printf("%d %d\n",check(ans),ans==inf?0:ans);
}