/*
https://www.luogu.org/problemnew/show/P4180
先吐槽一下蓝书：书上的做法是 O(n^2 log n) 的，还要开一个 n×n 的矩阵，对本题的数据范围毫无用处。
先求出最小生成树（MST）。显然，严格次小生成树可以由 MST 删去一条树边、再加入一条非树边得到。
设加入的边连接 u、v，为保证结果次小，应删去 MST 中 u->v 路径上边权的最大值。
考虑一个细节：如果加入的边权与 u->v 路径上的最大边权相同，删去它后总权值不变，得到的不是"严格"次小，此时应删去路径上的严格次大值。
所以我们要完成的任务是：求树上两点之间路径边权的最大值和严格次大值。
用倍增 LCA 同时维护即可。
另外 inf 要开大：tot 最大可达 10^14，inf 开小了会 WA 掉 50% 的点。
*/
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<map>
#include<cmath>
using namespace std;
#define LM qwq
typedef int mainint;   // captured before the macro below, so main() still returns a real int
// Every `int` after this line is widened to long long: the MST weight sum can reach ~1e14.
#define int long long

// Reads one (optionally negative) decimal integer from stdin, skipping non-digit characters.
inline int read()
{
    int ans = 0, op = 1;
    char ch = getchar();
    while (ch < '0' || ch > '9')
    {
        if (ch == '-') op = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        (ans *= 10) += ch - '0';
        ch = getchar();
    }
    return ans * op;
}

// Sentinel far above any spanning-tree weight (tot can reach ~1e14), while -inf - based
// arithmetic is never performed (see main), so there is no overflow risk.
const long long inf = 0x3f3f3f3f3f3f3f3fLL;
const int maxn = 1e5 + 5;
int n, m;

// Adjacency-list edge of the MST (stored in both directions).
struct edge
{
    int to, next, cost;
} e[maxn * 6];

// Input edge (u, v) with weight w; ordered by weight for Kruskal.
struct ed
{
    int u, v, w;
    bool operator < (const ed& b) const
    {
        return w < b.w;
    }
} a[maxn * 6];

bool vis[maxn * 6];   // vis[i]: sorted edge a[i] was taken into the MST
int fir[maxn], alloc;

// Adds the undirected tree edge u-v with weight w to the adjacency list.
void adde(int u, int v, int w)
{
    e[++alloc].next = fir[u]; fir[u] = alloc; e[alloc].to = v; e[alloc].cost = w;
    swap(u, v);
    e[++alloc].next = fir[u]; fir[u] = alloc; e[alloc].to = v; e[alloc].cost = w;
}

// Disjoint-set union with path compression.
int par[maxn];
int find(int x)
{
    return x == par[x] ? x : par[x] = find(par[x]);
}
void merge(int x, int y)
{
    x = find(x), y = find(y);
    par[x] = y;
}
bool same(int x, int y)
{
    return find(x) == find(y);
}

int tot;   // total weight of the MST

// Kruskal: sorts a[1..m] by weight, links accepted edges into the adjacency list,
// marks them in vis[] (indices refer to the sorted order) and accumulates tot.
void kruskal()
{
    for (int i = 1; i <= n; i++) par[i] = i;
    sort(a + 1, a + 1 + m);
    for (int i = 1; i <= m; i++)
    {
        int u = a[i].u, v = a[i].v;
        if (same(u, v)) continue;
        merge(u, v);
        adde(u, v, a[i].w);
        vis[i] = 1;
        tot += a[i].w;
    }
}

// Binary lifting tables over the MST:
//   f[u][i]  - the 2^i-th ancestor of u (0 past the root)
//   mx[u][i] - the maximum edge weight on the 2^i edges above u
//   mi[u][i] - the largest weight on those edges strictly below mx[u][i] (-inf if none)
int dep[maxn], f[maxn][21], mx[maxn][21], mi[maxn][21];
int order_q[maxn];   // BFS queue; every ancestor is dequeued before its descendants

// Fills dep/f/mx/mi for the tree containing `root` with a BFS instead of recursion,
// so a path-shaped MST of ~1e5 nodes cannot overflow the call stack.
// The caller must have set f[root][0], mx[root][0] and mi[root][0].
void build_lifting(int root)
{
    int head = 0, tail = 0;
    order_q[tail++] = root;
    while (head < tail)
    {
        int u = order_q[head++];
        int fa = f[u][0];
        dep[u] = dep[fa] + 1;
        for (int i = 1; i <= 20; i++)
        {
            // u's halfway ancestor; its tables are complete because it was dequeued earlier.
            int mid = f[u][i - 1];
            f[u][i] = f[mid][i - 1];
            mx[u][i] = max(mx[u][i - 1], mx[mid][i - 1]);
            mi[u][i] = max(mi[u][i - 1], mi[mid][i - 1]);
            // The smaller of two different half-maxima is a candidate for the strict second max.
            if (mx[u][i - 1] > mx[mid][i - 1]) mi[u][i] = max(mi[u][i], mx[mid][i - 1]);
            else if (mx[u][i - 1] < mx[mid][i - 1]) mi[u][i] = max(mi[u][i], mx[u][i - 1]);
        }
        for (int j = fir[u]; j; j = e[j].next)
        {
            int v = e[j].to;
            if (v == fa) continue;
            f[v][0] = u;
            mx[v][0] = e[j].cost;
            mi[v][0] = -inf;
            order_q[tail++] = v;
        }
    }
}

// Lowest common ancestor of x and y via binary lifting.
int lca(int x, int y)
{
    if (dep[x] < dep[y]) swap(x, y);
    for (int i = 20; i >= 0; i--)
    {
        if (dep[f[x][i]] >= dep[y]) x = f[x][i];
        if (x == y) return x;
    }
    for (int i = 20; i >= 0; i--)
        if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
    return f[x][0];
}

// Climbs from u to its ancestor v and returns the largest path edge weight different
// from val. Since every MST edge on the cycle closed by a non-tree edge of weight val is
// <= val, this is the largest weight strictly below val; returns -inf if there is none.
int get_max(int u, int v, int val)
{
    int ans = -inf;
    for (int i = 20; i >= 0; i--)
    {
        if (dep[f[u][i]] >= dep[v])
        {
            if (mx[u][i] != val) ans = max(ans, mx[u][i]);
            else ans = max(ans, mi[u][i]);
            u = f[u][i];
        }
    }
    return ans;
}

// Reads n, m and m weighted edges; prints the weight of the strict second-minimum spanning tree.
mainint main()
{
    n = read(), m = read();
    for (int i = 1; i <= m; i++) a[i].u = read(), a[i].v = read(), a[i].w = read();
    kruskal();
    mi[1][0] = -inf;
    build_lifting(1);
    int ans = inf;
    for (int i = 1; i <= m; i++)
    {
        if (vis[i]) continue;
        int u = a[i].u, v = a[i].v, w = a[i].w;
        int fa = lca(u, v);
        int maxm = max(get_max(u, fa, w), get_max(v, fa, w));
        // No path edge is strictly lighter than w (self-loop, or all path edges equal w):
        // swapping this edge in cannot give a strictly heavier tree, so skip it instead of
        // forming the meaningless candidate tot + inf + w.
        if (maxm == -inf) continue;
        ans = min(ans, tot - maxm + w);
    }
    printf("%lld", ans);
}