正题
题目链接:https://www.luogu.com.cn/problem/AT3611
题目大意
给出\(n\)个点的一棵树。
现在有一张完全图,两个点之间的边权为\(w_x+w_y+dis(x,y)\)(\(dis\)表示树上距离)
求这张完全图的最小生成树。
\(2\leq n\leq 2\times 10^5,1\leq w_i,c_i\leq 10^9\)
解题思路
考虑可能作为最小生成树的边。
一个结论就是对于一个子图。不在最小生成森林上的边一定不在原图的最小生成树上。
这样可以考虑分治,点分治之后对于根节点\(x\),其他的节点定义\(f_x=dep_x+w_x\),那么两个点之间边权就是\(f_x+f_y\)了(\(x,y\)属于不同子树),对于同一子树的我们也加进去,因为这是不优的边所以不会影响答案。
此时图中的最小生成森林是其他所有点连接\(f\)值最小的点。
这样我们可以处理出\(n\log n\)条可能的边,在这些边上再跑一次最小生成树就好了。
时间复杂度\(O(n\log^2 n)\)
code
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const ll N=2e5+10,inf=1e18;
struct node{
ll to,next,w;
}a[N<<1];
struct edge{
ll x,y,w;
}e[N<<5];
ll n,tot,mins,root,ans,num,ent;
ll ls[N],f[N],siz[N],w[N],fa[N];
bool v[N];
void addl(ll x,ll y,ll w){
a[++tot].to=y;
a[tot].next=ls[x];
ls[x]=tot;a[tot].w=w;
return;
}
void groot(ll x,ll fa){
siz[x]=1;f[x]=0;
for(ll i=ls[x];i;i=a[i].next){
ll y=a[i].to;
if(y==fa||v[y])continue;
groot(y,x);siz[x]+=siz[y];
f[x]=max(f[x],siz[y]);
}
f[x]=max(f[x],num-siz[x]);
if(f[x]<f[root])root=x;
return;
}
void calc(ll x,ll fa,ll dep){
f[x]=w[x]+dep;
if(f[x]<f[mins])mins=x;
for(ll i=ls[x];i;i=a[i].next){
ll y=a[i].to;
if(y==fa||v[y])continue;
calc(y,x,dep+a[i].w);
}
return;
}
void adde(ll x,ll fa){
e[++ent]=(edge){x,mins,f[x]+f[mins]};
for(ll i=ls[x];i;i=a[i].next){
ll y=a[i].to;
if(y==fa||v[y])continue;
adde(y,x);
}
}
void solve(ll x){
v[x]=1;f[x]=w[mins=x];
for(ll i=ls[x];i;i=a[i].next){
ll y=a[i].to;
if(v[y])continue;
calc(y,x,a[i].w);
}
e[++ent]=(edge){x,mins,f[x]+f[mins]};
for(ll i=ls[x];i;i=a[i].next){
ll y=a[i].to;
if(v[y])continue;
adde(y,x);
}
ll sum=num;
for(ll i=ls[x];i;i=a[i].next){
ll y=a[i].to;
if(v[y])continue;
num=(siz[y]>siz[x])?(sum-siz[x]):siz[y];
root=0;groot(y,x);solve(root);
}
return;
}
bool cmp(edge x,edge y)
{return x.w<y.w;}
ll find(ll x)
{return (fa[x]==x)?x:(fa[x]=find(fa[x]));}
signed main()
{
scanf("%lld",&n);
for(ll i=1;i<=n;i++)
scanf("%lld",&w[i]),fa[i]=i;
for(ll i=1;i<n;i++){
ll x,y,w;
scanf("%lld%lld%lld",&x,&y,&w);
addl(x,y,w);addl(y,x,w);
}
num=n;f[0]=inf;
groot(1,1);solve(root);
sort(e+1,e+1+ent,cmp);
for(ll i=1;i<=ent;i++){
ll x=e[i].x,y=e[i].y;
x=find(x);y=find(y);
if(x!=y)ans+=e[i].w,fa[y]=x;
}
printf("%lld\n",ans);
return 0;
}