通过lca来计算两个点之间的最大最小值,这样比暴力要快
#include<bits/stdc++.h> using namespace std; typedef long long ll; const int N=3e5+10; const int inf=0x3f3f3f3f; int h[N],ne[N],e[N],w[N],idx; int p[N]; int depth[N]; int f[N][20]; int d1[N][20],d2[N][20]; int n,m; int dis[N]; void add(int a,int b,int c){ e[idx]=b,ne[idx]=h[a],w[idx]=c,h[a]=idx++; } struct node{ int a,b,c; int f; bool operator <(const node &t) const{ return c<t.c; } }s[N]; int find(int x){ if(x!=p[x]){ p[x]=find(p[x]); } return p[x]; } ll kruscal(){ sort(s+1,s+1+m); ll res=0; for(int i=1;i<=m;i++){ int pa=find(s[i].a),pb=find(s[i].b); if(pa!=pb){ p[pa]=pb; res+=s[i].c; s[i].f=1; } } return res; } void build(){ int i; memset(h,-1,sizeof h); for(i=1;i<=m;i++){ if(s[i].f){ add(s[i].a,s[i].b,s[i].c); add(s[i].b,s[i].a,s[i].c); } } } void bfs(){ memset(depth,0x3f,sizeof depth); int i; depth[0]=0; depth[1]=1; queue<int> q; q.push(1); while(q.size()){ int t=q.front(); q.pop(); for(i=h[t];i!=-1;i=ne[i]){ int j=e[i]; if(depth[j]>depth[t]+1){ depth[j]=depth[t]+1; q.push(j); f[j][0]=t; int k; d1[j][0]=w[i],d2[j][0]=-inf; for(k=1;k<=18;k++){ int pa=f[j][k-1]; int distance[4]={d1[j][k-1],d2[j][k-1],d1[pa][k-1],d2[pa][k-1]}; f[j][k]=f[pa][k-1]; d1[j][k] =d2[j][k]=-inf; for(int u=0;u<4;u++){ if(distance[u]>d1[j][k]){ d2[j][k]=d1[j][k]; d1[j][k]=distance[u]; } else if(distance[u]!=d1[j][k]&&distance[u]>d2[j][k]){ d2[j][k]=distance[u]; } } } } } } } int lca(int a,int b,int c){ if(depth[a]<depth[b]) swap(a,b); int i; int cnt=0; for(i=18;i>=0;i--){ if(depth[f[a][i]]>=depth[b]){ dis[cnt++]=d1[a][i]; dis[cnt++]=d2[a][i]; a=f[a][i]; } } if(a!=b){ for(i=18;i>=0;i--){ if(f[a][i]!=f[b][i]){ dis[cnt++]=d1[a][i]; dis[cnt++]=d2[a][i]; dis[cnt++]=d1[b][i]; dis[cnt++]=d2[b][i]; a=f[a][i]; b=f[b][i]; } } dis[cnt++]=d1[a][0]; dis[cnt++]=d1[b][0]; } int df=-inf,ds=-inf; for(i=0;i<cnt;i++){ if(dis[i]>df){ ds=df,df=dis[i]; } else if(dis[i]!=df&&dis[i]>ds){ ds=dis[i]; } } if(c>df) return c-df; else return c-ds; } int main(){ int i; cin>>n>>m; for(i=0;i<=n;i++) p[i]=i; for(i=1;i<=m;i++){ int a,b,c; scanf("%d%d%d",&a,&b,&c); s[i]=node{a,b,c}; } ll sum=kruscal(); build(); bfs(); ll res=1e18; for(i=1;i<=m;i++){ if(!s[i].f){ int a=s[i].a,b=s[i].b,c=s[i].c; res=min(res,sum+lca(a,b,c)); } } cout<<res<<endl; }