一道非常良心卡空间的图论题,比赛的时候我一交直接MLE了。(然后才发现自己多按了个0)
题目中说要将所有的白点连向最近的黑点。在此之前,我们需要先将多余的边剔除掉 —— $m$居然是$n$的两倍,这能不剔边吗!
怎么剔呢?考虑增加一个超级原点,编号 $n + 1$, 最开始向所有黑点连一个$w = 0$的有向边.
接着,我们从超级原点跑一次最短路 ($Dyj$ 或者 $Spfa$ 随意), 然后我们就能得到一张最短路的图 -> 然后按照这个图建一张新的、精简过的图。 新边连接要求: $dis_v == dis_u + w$ -> 这样就能保证我们所连的边都是尽量短的。
当然,在这里要用到dfs来加边。如果刚刚我们的超级原点加了无向边,那么在这里会死循环,我们需要一个 $vis$数组来辅助。有向边则不需要。
然后,在新图上,如何保证连向最小的边? 最小生成树呗。 对新图一波 $Kruskal $ 就好了。
#include <bits/stdc++.h> using namespace std; #define N 200100 #define ll long long inline int read(){ int x = 0, s = 1; char c = getchar(); while(!isdigit(c)){ if(c == '-') s = -1; c = getchar(); } while(isdigit(c)){ x = x * 10 + (c ^ '0'); c = getchar(); } return x * s; } struct node{ int u, v; ll w; int next; } ; node t[N << 2]; node e[N << 2]; int f[N], head[N]; int n, m; int bian = 0; inline void add(int u, int v, ll w){ t[++bian] = (node){u, v, w, f[u]}, f[u] = bian; return ; } int bian2 = 0; inline void add1(int u, int v, ll w){ e[++bian2] = (node){u, v, w, head[u]}, head[u] = bian2; // e[++bian2] = (node){v, u, w, head[v]}, head[v] = bian2; return ; } /*基本操作 end*/ // 最短路 begin struct point{ ll u, dis; bool operator < (const point& a) const{ return dis > a.dis; } } ; ll d[N]; bool vis[N]; priority_queue <point> q; #define v t[i].v void Dyj(int s){ for(int i = 1;i <= n; i++) d[i] = (ll)1e18; d[s] = 0; q.push((point){s, 0}); while(!q.empty()){ point temp = q.top(); q.pop(); int now = temp.u; if(!vis[now]){ vis[now] = 1; for(int i = f[now]; i; i = t[i].next){ if(d[v] > d[now] + t[i].w){ d[v] = d[now] + t[i].w; if(!vis[v]) q.push((point){v, d[v]}); } } } } return ; } // 最短路 end // 最小生成树 begin void dfs(int now){ // 连新边 for(int i = f[now]; i; i = t[i].next){ if(d[v] == d[now] + t[i].w){ dfs(v); add1(now, v, t[i].w); } } return ; } #undef v bool cmp(node a, node b){ return a.w < b.w; } int fa[N]; int find(int x){ return x == fa[x] ? x : fa[x] = find(fa[x]); } void kruskal(){ sort(e + 1, e + bian2 + 1, cmp); for(int i = 1; i <= n + 1; i++) fa[i] = i; ll ans = 0; for(int i = 1;i <= bian2; i++){ int x = find(e[i].u), y = find(e[i].v); if(x != y){ ans += e[i].w; fa[x] = y; } } if(ans == 0) printf("impossible "); else printf("%lld ", ans); return ; } int main(){ // freopen("10.in", "r", stdin); // freopen("10.out", "w", stdout); n = read(), m = read(); for(int i = 1;i <= n; i++){ int x = read(); if(x == 1) add(n + 1, i, 0); /*必须加单向边,或者在dfs的时候要加 vis*/ } for(int i = 1;i <= m; i++){ int x = read(), y = read(), w = read(); add(x, y, w); add(y, x, w); } Dyj(n + 1); dfs(n + 1); kruskal(); return 0; }