▶ 书中第六章部分程序,加上自己补充的代码,包括全局最小切分 Stoer-Wagner 算法,最小权值二分图匹配
● 全局最小切分 Stoer-Wagner 算法
1 package package01; 2 3 import edu.princeton.cs.algs4.In; 4 import edu.princeton.cs.algs4.StdOut; 5 import edu.princeton.cs.algs4.EdgeWeightedGraph; 6 import edu.princeton.cs.algs4.Edge; 7 import edu.princeton.cs.algs4.UF; 8 import edu.princeton.cs.algs4.IndexMaxPQ; 9 10 public class class01 11 { 12 private static final double FLOATING_POINT_EPSILON = 1E-11; 13 private double weight = Double.POSITIVE_INFINITY; // 输出的最小割值 14 private boolean[] cut; // 顶点是否在割除集 T 中 15 private int V; 16 17 private class CutPhase // 最小 s-t 割类(cut-of-the-phase) 18 { 19 private double weight; 20 private int s; 21 private int t; 22 23 public CutPhase(double inputWeight, int inputS, int inputT) 24 { 25 weight = inputWeight; 26 s = inputS; 27 t = inputT; 28 } 29 } 30 31 public class01(EdgeWeightedGraph G) 32 { 33 UF uf = new UF(G.V()); // 用类 union–find 来表示顶点的合并情况 34 boolean[] marked = new boolean[G.V()]; // 已合并的顶点集,初始化为空 35 cut = new boolean[G.V()]; // 割除集 T,初始化为空 36 CutPhase cp = new CutPhase(0.0, 0, 0); // 用于首次搜索的割,无意义 37 for (int v = G.V(); --v > 0; marked[cp.t] = true) // 遍历 V-1 次,每次标记被合并的顶点,以后不再遍历该点 38 { 39 cp = minCutPhase(G, marked, cp); // 更新最小割 40 if (cp.weight < weight) // 发现权值更小的割,更新全局割 41 { 42 weight = cp.weight; 43 for (int j = 0; j < G.V(); j++) // 在最新的图中,与 cp.t 相连的顶点都是 T 的元素 44 cut[j] = uf.connected(j, cp.t); 45 } 46 G = contractEdge(G, cp.s, cp.t); // 顶点 t 合并到顶点 s,更新图 47 uf.union(cp.s, cp.t); // 顶点 t 加入 T 中 48 } 49 } 50 51 public double weight() 52 { 53 return weight; 54 } 55 56 public boolean cut(int v) 57 { 58 return cut[v]; 59 } 60 61 private CutPhase minCutPhase(EdgeWeightedGraph G, boolean[] marked, CutPhase cp) // 计算 s - t 的最小割,称为 maximum adjacency (cardinality) search 62 { 63 IndexMaxPQ<Double> pq = new IndexMaxPQ<Double>(G.V()); // 用于挑选权值最大的点用于合并 64 pq.insert(cp.s, Double.POSITIVE_INFINITY); // 顶点 s 自己的权值为 +∞ 65 for (int v = 0; v < G.V(); v++) // 其他顶点权值初始化为 0 66 { 67 if (v != cp.s && !marked[v]) 68 pq.insert(v, 0.0); 69 } 70 for (; !pq.isEmpty();) 71 { 72 cp.s = cp.t; // 记录最后取出的两个顶点,每次往前挪一格 73 cp.t = pq.delMax(); // 取走权值最大的顶点 v 74 for (Edge e : G.adj(cp.t)) // 只要 cp.t 的邻居还在 pq 中没有被取走,就更新邻居的权值 75 { 76 int w = e.other(cp.t); 77 if (pq.contains(w)) 78 pq.increaseKey(w, pq.keyOf(w) + e.weight()); // 顶点 w 的权值自增边 e 的权值 79 } 80 } 81 cp.weight = 0.0; // 计算最后加入 T 的顶点的权值,即为最小割值 82 for (Edge e : G.adj(cp.t)) 83 cp.weight += e.weight(); 84 return cp; 85 } 86 87 private EdgeWeightedGraph contractEdge(EdgeWeightedGraph G, int s, int t) // 把顶点 t 合并到顶点 s,更新其他边的权值 88 { 89 EdgeWeightedGraph H = new EdgeWeightedGraph(G.V()); // 合并后的图,顶点与原来相同 90 for (int v = 0; v < G.V(); v++) 91 { 92 for (Edge e : G.adj(v)) // 依顶点序遍历所有边 93 { 94 int w = e.other(v); 95 if (v == s && w == t || v == t && w == s) // 边 s-t 自身不要 96 continue; 97 if (v < w) // 只考虑后向边,滤掉重复 98 { 99 if (w == t) // 远端顶点 w 是被合并的 t,边 v - w(t) 替换成边 v - s 100 H.addEdge(new Edge(v, s, e.weight())); 101 else if (v == t) // 近端顶点 v 是被合并的 t,边 v(t) - w 替换成边 s - w 102 H.addEdge(new Edge(w, s, e.weight())); 103 else // 边 v - w 与 s 或 t 无关,原样放进 H 104 H.addEdge(new Edge(v, w, e.weight())); 105 } 106 } 107 } 108 return H; 109 } 110 111 public static void main(String[] args) 112 { 113 In in = new In(args[0]); 114 EdgeWeightedGraph G = new EdgeWeightedGraph(in); 115 class01 mc = new class01(G); 116 StdOut.print("Min cut: "); 117 for (int v = 0; v < G.V(); v++) 118 { 119 if (mc.cut(v)) 120 StdOut.print(v + " "); 121 } 122 StdOut.println(" Min cut weight = " + mc.weight()); 123 } 124 }
● 最小权值二分图匹配
1 package package01; 2 3 import edu.princeton.cs.algs4.StdOut; 4 import edu.princeton.cs.algs4.StdRandom; 5 import edu.princeton.cs.algs4.EdgeWeightedDigraph; 6 import edu.princeton.cs.algs4.DirectedEdge; 7 import edu.princeton.cs.algs4.DijkstraSP; 8 9 public class class01 10 { 11 private static final double FLOATING_POINT_EPSILON = 1E-14; 12 private int n; // 二分图一侧的顶点数,总顶点数2 * n 13 private double[][] weight; // 边权矩阵 14 private double minWeight; // 最小权值 15 private double[] px; // 每一行的对偶变量 16 private double[] py; // 每一列的对偶变量 17 private int[] xy; // 正向标记,xy[i] = j 表示 i-j 匹配 18 private int[] yx; // 反向标记,yx[j] = i 表示 i-j 匹配 19 20 public class01(double[][] inputWeight) 21 { 22 n = inputWeight.length; 23 weight = new double[n][n]; 24 minWeight = Double.MAX_VALUE; 25 for (int i = 0; i < n; i++) 26 { 27 for (int j = 0; j < n; j++) 28 { 29 if (Double.isNaN(weight[i][j])) 30 throw new IllegalArgumentException("weight " + i + "-" + j + " is NaN"); 31 weight[i][j] = inputWeight[i][j]; 32 minWeight = Math.min(minWeight, weight[i][j]); 33 } 34 } 35 px = new double[n]; 36 py = new double[n]; 37 xy = new int[n]; 38 yx = new int[n]; 39 for (int i = 0; i < n; i++) 40 xy[i] = yx[i] = -1; 41 42 for (int k = 0; k < n; k++) // 调整 n 次,每次添加一条边 43 { 44 assert isDualFeasibleAndComplementarySlack(); 45 augment(); 46 } 47 assert isDualFeasibleAndComplementarySlack() && isPerfectMatching(); // 检查结果正确性,要求互补松弛,完美匹配 48 } 49 50 private void augment() // 寻找最小权值路径并更新 51 { 52 EdgeWeightedDigraph G = new EdgeWeightedDigraph(2 * n + 2); 53 int s = 2 * n, t = 2 * n + 1; 54 for (int i = 0; i < n; i++) // 所有未匹配的顶点连到 s 和 t 上,s 侧权值为 0,t 侧有权值 55 { 56 if (xy[i] == -1) 57 G.addEdge(new DirectedEdge(s, i, 0.0)); 58 } 59 for (int j = 0; j < n; j++) 60 { 61 if (yx[j] == -1) 62 G.addEdge(new DirectedEdge(n + j, t, py[j])); 63 } 64 for (int i = 0; i < n; i++) 65 { 66 for (int j = 0; j < n; j++) // 已匹配的顶点对添加权值为 0 的反向边,未匹配的顶点对添加修正权值的正向边 67 { 68 if (xy[i] == j) 69 G.addEdge(new DirectedEdge(n + j, i, 0.0)); 70 else 71 G.addEdge(new DirectedEdge(i, n + j, reducedCost(i, j))); 72 } 73 } 74 DijkstraSP spt = new DijkstraSP(G, s); // 计算从 s 到各顶点最短距离 75 for (DirectedEdge e : spt.pathTo(t)) // 研究从 s 到 t 的边 76 { 77 int i = e.from(), j = e.to() - n; 78 if (i < n) // 去掉与顶点 s 和 t 有关的部分和反向边 79 { 80 xy[i] = j; 81 yx[j] = i; 82 } 83 } 84 for (int i = 0; i < n; i++) // 垫起各顶点的距离 85 { 86 px[i] += spt.distTo(i); 87 py[i] += spt.distTo(n + i); 88 } 89 } 90 91 private double reducedCost(int i, int j) // 顶点对之间的修正权值,用原始权值减去全局最小权值,再加上起点、终点的距离差 92 { 93 double reducedCost = (weight[i][j] - minWeight) + px[i] - py[j]; 94 assert reducedCost >= 0.0; 95 if (Math.abs(reducedCost) <= FLOATING_POINT_EPSILON * (Math.abs(weight[i][j]) + Math.abs(px[i]) + Math.abs(py[j]))) 96 return 0.0; 97 return reducedCost; 98 } 99 100 public double dualRow(int i) 101 { 102 return px[i]; 103 } 104 105 public double dualCol(int j) 106 { 107 return py[j]; 108 } 109 110 public int sol(int i) 111 { 112 return xy[i]; 113 } 114 115 public double weight() // 输出解的权值总和 116 { 117 double total = 0.0; 118 for (int i = 0; i < n; i++) 119 { 120 if (xy[i] != -1) 121 total += weight[i][xy[i]]; 122 } 123 return total; 124 } 125 126 private boolean isDualFeasibleAndComplementarySlack() // 检查对偶可行性和互补松弛性 127 { 128 for (int i = 0; i < n; i++) // 检查所有边的修正权值不小于 0 129 { 130 for (int j = 0; j < n; j++) 131 { 132 if (reducedCost(i, j) < 0) 133 { 134 StdOut.println("Dual variables are not feasible"); 135 return false; 136 } 137 } 138 } 139 for (int i = 0; i < n; i++) // 检查原变量和对偶变量的互补松弛性,即已匹配的顶点对修正权值为 0,未匹配的非 0 140 { 141 if (xy[i] != -1 && reducedCost(i, xy[i]) != 0) 142 { 143 StdOut.println("Primal and dual variables are not complementary slack"); 144 return false; 145 } 146 } 147 return true; 148 } 149 150 private boolean isPerfectMatching()// 检查是否为完美匹配 151 { 152 boolean[] perm = new boolean[n]; 153 for (int i = 0; i < n; i++) 154 { 155 if (perm[xy[i]])// ?perm[-1] 初始化为 true? 156 { 157 StdOut.println("Not a perfect matching"); 158 return false; 159 } 160 perm[xy[i]] = true; 161 } 162 for (int j = 0; j < n; j++)// 检查 xy[] 和 yx[] 对称性 163 { 164 if (xy[yx[j]] != j) 165 { 166 StdOut.println("xy[] and yx[] are not inverses"); 167 return false; 168 } 169 } 170 for (int i = 0; i < n; i++) 171 { 172 if (yx[xy[i]] != i) 173 { 174 StdOut.println("xy[] and yx[] are not inverses"); 175 return false; 176 } 177 } 178 return true; 179 } 180 181 public static void main(String[] args) 182 { 183 int n = Integer.parseInt(args[0]); 184 double[][] weight = new double[n][n]; 185 for (int i = 0; i < n; i++) 186 { 187 for (int j = 0; j < n; j++) 188 weight[i][j] = StdRandom.uniform(900) + 100; // 权值 100 ~ 999 189 } 190 191 class01 assignment = new class01(weight); 192 StdOut.printf("weight = %.0f ", assignment.weight()); 193 194 if (n >= 20) 195 return; 196 for (int i = 0; i < n; i++) 197 { 198 for (int j = 0; j < n; j++) 199 StdOut.printf("%c%.0f ", (j == assignment.sol(i)) ? '*' : ' ', weight[i][j]); 200 StdOut.println(); 201 } 202 } 203 }