• KM算法及其优化的学习笔记&&bzoj2539: [Ctsc2000]丘比特的烦恼


    感谢  http://www.cnblogs.com/vongang/archive/2012/04/28/2475731.html

    这篇blog里提供了3个链接……基本上很明白地把KM算法是啥讲清楚了

    然而n^4的KM好像并没有什么卵用啊……所以不得不学n^3的

    我看了一下各种,大部分blog里写的声称是n^3的KM,其实貌似都是n^4的(包括上面的链接以及上面链接里提供的链接)

    这是因为他们有个共同点

    他们虽然用slack数的优化组避免了暴力枚举d所消耗的时间,但由于一次增广是n^2的,所以拖慢了复杂度

    那么怎么解决这个问题呢?

    尛焱轟告诉我们,用bfs增广的KM是n^3的,用dfs增广的KM是n^4的

    尛焱轟还告诉我们,可以去UOJ上拉个板子,都是n^3的

    于是窝就拉了个策爷的板子来看(然后改了几个变量名,背下来,就学完了……)

    为什么dfs会成为算法时间复杂度减小的瓶颈呢?

    我们发现,每更新顶标,就要重新从当前点开始dfs找一遍增广路,有很多冗余的操作

    实际上,更新完顶标之后,交错树只会增加新的点

    那么窝萌不妨用bfs来增广,每次修改完顶标,把没访问到的右侧点的slack值也相应地减去d,那么slack值为0就相当于多了一条可行边,就相当于能够访问到新的节点,也就可以继续找增广路了

    这样再把新的点加进队列,就避免了dfs增广的版本中的冗余操作

    这样就发挥了slack这一优化的优势,复杂度自然降到O(n^3)

    然后窝来帖一下窝的代码(uoj#80 二分图最大权匹配)

    #include <iostream>
    #include <cstdio>
    #include <cmath>
    #include <cstring>
    #include <cstdlib>
    #include <algorithm>
    #define ll long long
    #define N 405
    #define INF (1LL<<60)
    
    using namespace std;
    inline int read(){
    	int ret=0;char ch=getchar();
    	while (ch<'0'||ch>'9') ch=getchar();
    	while ('0'<=ch&&ch<='9'){
    		ret=ret*10-48+ch;
    		ch=getchar();
    	}
    	return ret; 
    }
    
    int n,fx[N],fy[N],prev[N];
    ll g[N][N],A[N],B[N],slk[N];
    bool visx[N],visy[N];
    int q[N],qh,qt;
    
    void aug(int v){
    	if (!v) return;
    	fy[v]=prev[v];
    	aug(fx[prev[v]]);
    	fx[fy[v]]=v;
    }
    
    void bfs_KM(int _s){
    	memset(visx,0,sizeof(visx));
    	memset(visy,0,sizeof(visy));
    	memset(slk,127,sizeof(slk));
    	qh=qt=0;
    	q[++qt]=_s;
    	for (;;){
    		while (qh<qt){
    			int u=q[++qh];
    			visx[u]=1;
    			for (int v=1;v<=n;++v)if (!visy[v]){
    				if (A[u]+B[v]==g[u][v]){
    					visy[v]=1;
    					prev[v]=u;
    					if (!fy[v]){aug(v);return;}
    					q[++qt]=fy[v];
    					continue;
    				}
    				if (slk[v]>A[u]+B[v]-g[u][v]){
    					slk[v]=A[u]+B[v]-g[u][v];
    					prev[v]=u;
    				}
    			}
    		}
    		ll d=INF;
    		for (int i=1;i<=n;++i)
    			if (!visy[i]) d=min(d,slk[i]);
    		for (int i=1;i<=n;++i){
    			if (visx[i]) A[i]-=d;
    			if (visy[i]) B[i]+=d;
    			else slk[i]-=d;
    		}
    		for (int v=1;v<=n;++v)if (!visy[v]&&!slk[v]){
    			visy[v]=1;
    			if (!fy[v]){aug(v);return;}
    			q[++qt]=fy[v];
    		}
    	}
    }
    
    ll KM(){
    	for (int i=1;i<=n;++i){
    		A[i]=-INF;B[i]=0;
    		for (int j=1;j<=n;++j)
    			A[i]=max(A[i],g[i][j]);
    	}
    	memset(fx,0,sizeof(fx));
    	memset(fy,0,sizeof(fy));
    	for (int i=1;i<=n;++i) bfs_KM(i);
    	ll ret=0;
    	for (int i=1;i<=n;++i) ret+=A[i]+B[i];
    	return ret;
    }
    
    bool e0[N][N];
    int main(){
    	int nl=read(),nr=read();
    	memset(g,0,sizeof(g));
    	memset(e0,0,sizeof(e0));
    	for (int m0=read();m0;--m0){
    		int u=read(),v=read();
    		g[u][v]=read();
    		e0[u][v]=1;
    	}
    	n=max(nl,nr);
    	ll ans=KM();
    	printf("%lld
    ",ans);
    	for (int i=1;i<=nl;++i)
    		printf("%d ",e0[i][fx[i]]*fx[i]);
    	puts("");
    	return 0;
    }
    

    感谢尛焱轟神犇的指点

    感谢jcvb神犇的代码

    感谢上面的那篇blog以及那篇blog里的链接

    更新一下,窝在丘比特的烦恼(KM模板题)里把KM封装了一下,方便大(我)家(拖)学(板)习(子)

    顺便提一下此题的几个坑爹的地方:1坐标可能为负,2姓名无视大小写,3必须连n对情侣(也就是说不能连的边权必须赋为-INF)

    #include <iostream>
    #include <cstdio>
    #include <cstring>
    #include <cmath>
    #include <cstdlib>
    #include <algorithm>
    #include <map>
    #include <string>
    #define N 32
    #define INF (1e9)
    
    using namespace std;
    inline int read(){
    	int ret=0;char ch=getchar();
    	bool flag=0;
    	while (ch<'0'||ch>'9'){
    		flag=ch=='-';
    		ch=getchar();
    	}
    	while ('0'<=ch&&ch<='9'){
    		ret=ret*10-48+ch;
    		ch=getchar();
    	}
    	return flag?-ret:ret;
    }
    
    struct KM{
    	int n;
    	int g[N][N],slk[N],A[N],B[N];
    	int prev[N],fx[N],fy[N];
    	bool visx[N],visy[N];
    	int q[N],qh,qt;
    	void clear(int _n){
    		n=_n;memset(g,0,sizeof(g));
    	}
    	void AddEdge(int u,int v,int w){
    		g[u][v]=w;
    	}
    	void aug(int v){
    		if (!v) return;
    		fy[v]=prev[v];
    		aug(fx[fy[v]]);
    		fx[fy[v]]=v;
    	}
    	void bfs(int _s){
    		memset(visx,0,sizeof(visx));
    		memset(visy,0,sizeof(visy));
    		memset(slk,127,sizeof(slk));
    		qh=qt=0;q[++qt]=_s;
    		for (;;){
    			while (qh<qt){
    				int u=q[++qh];
    				visx[u]=1;
    				for (int v=1;v<=n;++v)if (!visy[v]){
    					if (A[u]+B[v]==g[u][v]){
    						visy[v]=1;
    						prev[v]=u;
    						if (!fy[v]){aug(v);return;}
    						q[++qt]=fy[v];
    					}
    					else if (slk[v]>A[u]+B[v]-g[u][v]){
    						slk[v]=A[u]+B[v]-g[u][v];
    						prev[v]=u;
    					}
    				}
    			}
    			int d=INF;
    			for (int i=1;i<=n;++i)if (!visy[i])d=min(d,slk[i]);
    			for (int i=1;i<=n;++i){
    				if (visx[i]) A[i]-=d;
    				if (visy[i]) B[i]+=d;
    				else slk[i]-=d;
    			}
    			for (int v=1;v<=n;++v)
    				if (!visy[v]&&!slk[v]){
    					visy[v]=1;
    					if (!fy[v]){aug(v);return;}
    					q[++qt]=fy[v];
    				}
    		}
    	}
    	int solve(){
    		memset(A,128,sizeof(A));
    		memset(B,0,sizeof(B));
    		memset(fx,0,sizeof(fx));
    		memset(fy,0,sizeof(fy));
    		for (int i=1;i<=n;++i)
    			for (int j=1;j<=n;++j) A[i]=max(A[i],g[i][j]);
    		for (int i=1;i<=n;++i) bfs(i);
    		int ret=0;
    		for (int i=1;i<=n;++i) ret+=A[i]+B[i];
    		return ret;
    	}
    } km;
    
    
    int n;
    map<string,int> id;
    int x[N*2],y[N*2],lmt;
    void Upper(string &s){
    	int l=s.length();
    	for (int i=0;i<l;++i)if (s[i]>'Z') s[i]-=32;
    }
    
    int main(){
    	string tmp;
    	lmt=read();n=read();
    	for (int i=1;i<=2*n;++i){
    		x[i]=read();y[i]=read();
    		cin>>tmp;Upper(tmp);id[tmp]=i;
    	}
    	km.clear(n);
    	for (int i=1;i<=n;++i)
    		for (int j=1;j<=n;++j)
    			km.AddEdge(i,j,1);
    	for (cin>>tmp;tmp!="End";cin>>tmp){
    		Upper(tmp);
    		int u=id[tmp],v;
    		cin>>tmp;Upper(tmp);v=id[tmp];
    		if (u>v) swap(u,v);
    		km.AddEdge(u,v-n,read());
    	}
    	for (int i=1;i<=n;++i)
    		for (int j=n+1;j<=2*n;++j){
    			bool found=(x[i]-x[j])*(x[i]-x[j])+(y[i]-y[j])*(y[i]-y[j])>lmt*lmt;
    			for (int k=1;k<=2*n&&!found;++k){
    				int A=y[k]-y[j],B=x[k]-x[j],C=y[k]-y[i],D=x[k]-x[i];
    				found=A*D==B*C&&(A*C<0||B*D<0);
    			}
    			if (found) km.AddEdge(i,j-n,-1e7);
    		}
    	printf("%d
    ",km.solve());
    	return 0;
    }
    
  • 相关阅读:
    mysql批量替换指定字符串
    php中英字符串截取
    比较两个JSON字符串是否完全相等
    Java FastJson 介绍
    线程池
    DBUS及常用接口介绍
    在Mac中如何正确地设置JAVA_HOME
    base64 原理
    sizeof与strlen的区别
    Kubernetes 部署失败的 10 个最普遍原因
  • 原文地址:https://www.cnblogs.com/wangyurzee7/p/5215231.html
Copyright © 2020-2023  润新知