题解
又写了一遍KM算法,这题刚好是把最大最小KM拼在一起写的,感觉比较有记录价值
感觉KM始终不熟啊QAQ
算法流程大抵如下,原理就是每次我们通过减少最少的匹配量达成最大匹配,所以获得的一定是最大价值
1.我们先给左部点求一个期望大小,如果是最大KM,期望大小就是最大的那条边的权值,如果是最小KM,期望大小就是最小的那条边的权值
2.然后跑二分图匹配,两个点能匹配的条件是左点(u)的期望值加右点(v)的期望值刚好是边权
3.给无法访问的点更新断层大小,如果是最小匹配,那么断层就是(c[u][v] - (ex_l[u] + ex_r[v])),如果是最大匹配,就是((ex_l[u] + ex_r[v]) - c[u][v])
4.在没有被访问的右点里寻找最小的减少量(d)
5.给访问过的左点和右点,如果是最小匹配,左点加上(d),因为要包括进一些更大的边,右点减去(d),如果是最大匹配,左点减去(d),因为要包括进一些更小的边,右点加上(d)
代码
#include <bits/stdc++.h>
#define fi first
#define se second
#define pii pair<int, int>
#define pdi pair<db, int>
#define mp make_pair
#define pb push_back
#define enter putchar('
')
#define space putchar(' ')
#define eps 1e-8
#define mo 974711
#define MAXN 205
//#define ivorysi
using namespace std;
typedef long long int64;
typedef double db;
template <class T>
void read(T &res) {
res = 0;
char c = getchar();
T f = 1;
while (c < '0' || c > '9') {
if (c == '-')
f = -1;
c = getchar();
}
while (c >= '0' && c <= '9') {
res = res * 10 + c - '0';
c = getchar();
}
res *= f;
}
template <class T>
void out(T x) {
if (x < 0) {
x = -x;
putchar('-');
}
if (x >= 10) {
out(x / 10);
}
putchar('0' + x % 10);
}
int N;
int c[105][105];
int ex_l[105], ex_r[105], slack[105], matc[105];
bool vis_l[105], vis_r[105];
bool match(int u, int on) {
vis_l[u] = 1;
for (int v = 1; v <= N; ++v) {
if (!vis_r[v]) {
if (vis_r[v])
continue;
int gap;
if (on == 0)
gap = c[u][v] - ex_l[u] - ex_r[v];
else
gap = ex_l[u] + ex_r[v] - c[u][v];
if (gap == 0) {
vis_r[v] = 1;
if (!matc[v] || match(matc[v], on)) {
matc[v] = u;
return true;
}
} else
slack[v] = min(slack[v], gap);
}
}
return false;
}
int KM(int on) {
for (int i = 1; i <= N; ++i) {
ex_r[i] = 0;
if (on == 0)
ex_l[i] = 0x7fffffff;
else
ex_l[i] = 0;
for (int j = 1; j <= N; ++j) {
if (on == 0)
ex_l[i] = min(ex_l[i], c[i][j]);
else
ex_l[i] = max(ex_l[i], c[i][j]);
}
}
memset(matc, 0, sizeof(matc));
for (int i = 1; i <= N; ++i) {
for (int j = 1; j <= N; ++j) slack[j] = 0x7fffffff;
while (1) {
memset(vis_l, 0, sizeof(vis_l));
memset(vis_r, 0, sizeof(vis_r));
if (match(i, on))
break;
int d = 0x7fffffff;
for (int j = 1; j <= N; ++j) {
if (!vis_r[j])
d = min(d, slack[j]);
}
for (int j = 1; j <= N; ++j) {
if (on == 0) {
if (vis_l[j])
ex_l[j] += d;
if (vis_r[j])
ex_r[j] -= d;
else
slack[j] -= d;
} else {
if (vis_l[j])
ex_l[j] -= d;
if (vis_r[j])
ex_r[j] += d;
else
slack[j] -= d;
}
}
}
}
int res = 0;
for (int v = 1; v <= N; ++v) {
res += c[matc[v]][v];
}
return res;
}
void Solve() {
read(N);
for (int i = 1; i <= N; ++i) {
for (int j = 1; j <= N; ++j) {
read(c[i][j]);
}
}
out(KM(0));
enter;
out(KM(1));
enter;
}
int main() {
#ifdef ivorysi
freopen("f1.in", "r", stdin);
#endif
Solve();
return 0;
}