• wqs 二分学习笔记


    又称为带权二分

    一种优化凸函数 dp 的方式,明显的标志是选 k 个。

    一般这种玩意都是可以强套一个 wqs 二分上去,消一个 O(n) 加一个 (O(log)),而且还是从状态数上消一个。

    我们从 LCT 这道题来引入。

    首先题目要求选 k+1 条不相交链的权值和最大。

    设出 (dp[i][j][0/1/2]) 表示以 (i) 为根的子树在图上的度数为 (0,1,2),他的子树中含有 (j) 条链。并且当度数为 (1) 的时候这条链不计入 (j) 中。

    分类讨论的有点繁琐就不写了,看看代码吧(讲真,这个 dp 还挺神仙的

    我觉得 xtq 树形 dp 的方式还挺用的

    inline void dfs(int x,int ff)
    {
    	dp[0][x][0] = dp[1][x][0] = dp[2][x][1] = dp[3][x][0] = 0;
    	for(int i=head[x],v;i;i=nxt[i])
    	{
    		v = ver[i];
    		if(v == ff) continue;
    		dfs(v , x);
    		// for(int u = 0;u <= 2;u++)
    		// 	for(int j=0;j <= k;j++) aux[u][j] = -INF;
    		memset(aux[0],0xcf,sizeof(aux[0])),memset(aux[1],0xcf,sizeof(aux[1]));
    		memset(aux[2],0xcf,sizeof(aux[2]));
    		for(int u = 0;u <= k;u++)
    			for(int q = 0;q + u <= k;q ++)
    				aux[0][u + q] = max(aux[0][u + q],dp[0][x][u] + dp[3][v][q]);
    		for(int u = 0;u <= k;u ++)
    			for(int q = 0;q + u <= k;q ++)
    				aux[1][u + q] = max(aux[1][u + q],max(dp[1][x][u] + dp[3][v][q],dp[0][x][u] + dp[1][v][q] + Edge[i]));
    		for(int u = 0;u <= k ; u++)
    			for(int q = 0;q + u<= k;q ++)
    			{
    				aux[2][u + q] = max(aux[2][u + q],dp[2][x][u] + dp[3][v][q]);
    				// aux[2][u + q + 1] = max(aux[2][u + q + 1],dp[1][x][u] + dp[1][v][q] + Edge[i]);
    			}
    
    		for(int u = 0;u <= k ; u++)
    			for(int q = 0;q + u + 1<= k;q ++)
    			{
    				// aux[2][u + q + 1] = max(aux[2][u + q + 1],dp[2][x][u] + dp[3][v][q]);
    				aux[2][u + q + 1] = max(aux[2][u + q + 1],dp[1][x][u] + dp[1][v][q] + Edge[i]);
    			}
    			// printf("%d : 
    ",x);
    		for(int u =0 ;u <=2;u++)
    			for(int q = 0;q <= k;q ++ )
    			{
    				// printf("aux[%d][%d] = %d
    ",u,q,aux[u][q]);
    				dp[u][x][q] = aux[u][q];
    			} 
    		dp[1][x][0] = max(dp[1][x][0] , dp[1][v][0] + Edge[i]);
    	}
    	// dp[0][x][1] = max(0,dp[0][x][1]);
    	for(int i=1;i<=k;i++)
    	dp[3][x][i] = max(dp[0][x][i],max(dp[1][x][i - 1],dp[2][x][i]));
    	// for(int i=1;i<=)
    }
    

    我们现在的 (dp)(O(nk^2)) 的。这个复杂度大的离谱。

    这时候请出我们的带权二分来。这里默认我们的 dp 函数是一个凸函数

    我们将原来的 (dp)(j) 那维限制去掉,这样就可以将复杂度降到 (O(n))。但是这样不能保证我们恰好选了 k+1 条链,所以我们要对原函数做一些魔改。

    设原函数为 (ans(x)),当 (ans'(x)=0) 时,ans(x) 取得最大值。我们不加修改的 dp 求出来的就是这个东西。

    现在我们设一个新函数 (g(x) = ans(x) +val imes x),这个函数一阶导为减函数,二阶导为一个上凸函数。所以我们可以通过调节 val (就是斜率)来调节 (g'(x)) 的零点,这样就能调节出当 (g(x)) 取得最值得时候(ans(x)) 恰好取得 (k+1) 条链,这样皆大欢喜。

    关于恰好选 k 个是一个凸函数,你要想如果恰好选 (1) 个,那我们肯定选择最大的那个,选两个我会把次大的选上,这样每次的增量都不如上一个大,就会形成一个凸函数。更仔细一点,如果有那么一点点限制,要求恰好选 (k) 个,那我后面选的东西可能会影响到前面选的,并且这时候还要求数量达到我们要求的,就被迫舍弃权值最优,来追求数量,这就导致了凸函数的后半段产生。

    复杂度 (O(blog k))

    关于 wqs 的实际操作来说,有一点点细节。

    对最大值来说:你考虑二分一个惩罚值,当你选的少了的时候,我们想让它下次选得再多一点,就会把惩罚值下调,反之就会上调。

    对最小值来说:选得少了的时候,我们想让它下次再多选一点,惩罚值就会下调,反之上调。

    对于多点共线的情况,我们优先选物品最少的或者最多的,二分的时候只要物品在 k 的我们指定的一侧时就去更新答案。

    P4383 [八省联考2018]林克卡特树

    #include<bits/stdc++.h>
    
    using namespace std;
    
    #define int long long
    #define pii pair<int,int>
    
    template<typename _T>
    inline void read(_T &x)
    {
    	x=0;char s=getchar();int f=1;
    	while(s<'0'||s>'9') {f=1;if(s=='-')f=-1;s=getchar();}
    	while('0'<=s&&s<='9'){x=(x<<3)+(x<<1)+s-'0';s=getchar();}
    	x*=f;
    }
    const int np = 3e5 + 5;
    int head[np],ver[np * 2],nxt[np * 2],Edge[np * 2];
    int tit;
    inline void add(int x,int y,int w)
    {
    	ver[++tit] = y;
    	Edge[tit] = w;
    	nxt[tit] = head[x];
    	head[x] = tit;
    }
    
    struct qwq
    {
    	int f,fanga;
    
    	friend qwq operator+(qwq a,qwq b)
    	{
    		return (qwq){a.f + b.f , a.fanga + b.fanga};
    	}
    
    	inline friend  qwq Max(qwq a,qwq b)
    	{
    		if(a.f == b.f)
    		{
    			if(a.fanga > b.fanga) return a;
    			else return b;
    		} 
    		if(a.f > b.f) return a;
    		else return b;
    	}
    
    }dp[5][np];
    // dp[d][i] 表示前以 i 为根的树,度数为 d 的
    int n,k,sakura;
    
    inline void dfs(int x,int ff)
    {
    	dp[0][x] = (qwq){0,0},dp[1][x] = (qwq){0,0},dp[2][x] = (qwq){sakura,1};
    	for(int i=head[x],v;i;i=nxt[i])
    	{
    		v = ver[i];
    		if(v == ff) continue;
    		dfs(v,x);
    		dp[2][x] = Max(dp[2][x] + dp[3][v],dp[1][x] + dp[1][v] + (qwq){Edge[i] + sakura,1});
    		dp[1][x] = Max(dp[1][x] + dp[3][v],dp[0][x] + dp[1][v] + (qwq){Edge[i],0});
    		dp[0][x] = dp[0][x] + dp[3][v];
    	}
    	dp[3][x] = Max(dp[0][x],Max(dp[1][x] + (qwq){sakura,1},dp[2][x]));
    }
    
    inline void judging(int x)
    {
    	sakura = x;
    	dfs(1,0);
    }
    
    signed main()
    {
    	read(n),read(k);
    	k++;
    	for(int i=1,a,b,w;i<=n - 1;i ++ )
    	{
    		read(a),read(b),read(w);
    		add(a,b,w);
    		add(b,a,w);
    	}
    	int l = -1e8,r = 1e8,Ans=0;
    	while(l <= r)
    	{
    		int mid = l + r >> 1;
    		judging(mid);
    		if(dp[3][1].fanga >= k)
    		{
    			Ans = dp[3][1].f - k * mid;
    			// printf("%lld %lld
    ",dp[3][1].fanga,Ans);
    			r = mid - 1;
    		}
    		else l = mid + 1;
    	}
    	printf("%lld",Ans);
    }
    

    P6246 [IOI2000] 邮局 加强版

    #include<bits/stdc++.h>
    
    using namespace std;
    
    #define int long long
    #define pii pair<int,int>
    
    template<typename _T>
    inline void read(_T &x)
    {
    	x=0;char s=getchar();int f=1;
    	while(s<'0'||s>'9') {f=1;if(s=='-')f=-1;s=getchar();}
    	while('0'<=s&&s<='9'){x=(x<<3)+(x<<1)+s-'0';s=getchar();}
    	x*=f;
    }
    const int np = 1e6 + 5;
    int a[np];
    int sum[np],n,k;
    
    inline int Abs(int x)
    {
    	return x < 0?-x:x;
    }
    
    inline int calc(int l,int r)
    {
    	int op = l + r >> 1;
    	return a[op] * (op - l + 1) - sum[op]+sum[l - 1] + Abs(a[op] * (r-op+1) - sum[r] + sum[op - 1]);
    }
    
    struct qwq
    {
    	int f,fanga;
    	friend qwq operator+(qwq a,qwq b)
    	{
    		return (qwq){a.f + b.f,a.fanga + b.fanga};
    	}
    	friend bool operator<(qwq a,qwq b)
    	{
    		if(a.f == b.f) return a.fanga < b.fanga;
    		else return a.f < b.f;
    	}
    }dp[np];
    
    int sakura;
    // int l_[2333],r_[2333],juec[2333];
    struct qaq
    {
    	int l_,r_,juec;
    	// int nx
    }que[np * 2];
    int top = 0;
    
    inline int binary(qaq u,int op)
    {
    	int l = u.l_,r = u.r_,opt = u.juec,ans =  u.r_ + 1;//l <= op?op:0;
    	while(l <= r)
    	{
    		int mid = l + r >> 1;
    		if(dp[op] + (qwq){calc(op + 1,mid) + sakura,1} < dp[opt] + (qwq){calc(opt + 1,mid) + sakura,1}) ans = mid,r = mid - 1;
    		else l = mid + 1;
    	}
    	return ans;
    }
    
    inline void solve()
    {
    	int head = 1,tail = 1;
    	dp[0] = (qwq){0,0};
    	que[head] = (qaq){1,n,0};
    	for(int i=1;i<=n;i++)
    	{
    		while(head < tail && que[head].r_ < i) head++;
    		int j = que[head].juec;
    		dp[i] = dp[j] + (qwq){calc(j + 1,i) + sakura,1};// + sakura;
    		int spilt = 0;
    		while(head < tail && que[tail].l_ == binary(que[tail],i)) spilt = que[tail].l_,tail--;
    		spilt = binary(que[tail],i);
    		if(spilt)
    		{
    			que[tail].r_ = spilt - 1;
    	//		printf("%lld ",spilt);
    			que[++tail] = (qaq){spilt,n,i};			
    		}
    	}
    	// if(sakura == 0)
    //	for(int i=1;i<=n;i++)
    //	{
    //		printf("%lld ",dp[i].f);
    //	}
    //	printf("
    ");
    }
    
    namespace subtask{
    	int fp[500][4333];
    	inline int bbinary(int c,qaq u,int op)
    	{
    		int l = u.l_,r = u.r_,opt = u.juec,ans =  u.r_ + 1;//l <= op?op:0;
    		while(l <= r)
    		{
    			int mid = l + r >> 1;
    			if(fp[c - 1][op] + calc(op + 1,mid) < fp[c - 1][opt] + calc(opt + 1,mid)) ans = mid,r = mid - 1;
    			else l = mid + 1;
    		}
    		return ans;
    	}
    	inline void solve1(int c)
    	{
    		int head = 1,tail = 1;
    		fp[c][0] = 0;
    		que[head] = (qaq){1,n,0};
    		for(int i=1;i<=n;i++)
    		{
    			while(head < tail && que[head].r_ < i) head++;
    			int j = que[head].juec;
    			fp[c][i] = fp[c - 1][j] + calc(j + 1,i);// + sakura;
    			int spilt = 0;
    			while(head < tail && que[tail].l_ == bbinary(c,que[tail],i)) spilt = que[tail].l_,tail--;
    //			int spilt = 0;
    			spilt = bbinary(c,que[tail],i);
    			if(spilt)
    			{
    				que[tail].r_ = spilt - 1;
    		//		printf("%lld ",spilt);
    				que[++tail] = (qaq){spilt,n,i};			
    			}
    		}
    	}
    
    	inline void Main()
    	{
    		memset(fp,0x3f,sizeof(fp));
    		fp[1][0] = 0;
    		for(int i=1;i<=n;i++) fp[1][i] = calc(1,i);
    		for(int i=2;i <= k;i++)
    		{
    			solve1(i);
    //			for(int j=1;j<=n;j++)
    //			printf("%lld ",fp[i][j]);
    //			printf("
    ");
    		}
    		printf("%lld",fp[k][n]);
    	}
    }
    
    inline void judging(int x)
    {
    	sakura = x;
    	solve();
    }
    
    signed main()
    {
    	read(n),read(k);
    	for(int i=1;i<=n;i++)
    	{
    		read(a[i]);
    		sum[i] = sum[i - 1] + a[i];
    	}	
    	int l = -1e8,r = 1e8,Ans(0);
    	 while(l <= r)
    	 {
    	 	int mid = l + r >> 1;
    	 	judging(mid);
    //	 	printf("%lld %lld
    ",dp[n].f,dp[n].fanga);
    	 	if(dp[n].fanga <= k)
    	 	{
    	 		Ans = dp[n].f - k * sakura;
    //			printf("%lld
    ",Ans);
    	 		r = mid - 1;
    	 	}
    	 	else l = mid + 1;
    	 }
    	printf("%lld",Ans);
    }
    

    据说这个东西还能用来优化个模拟费用流啥的,笑死,根本不会费用流。

  • 相关阅读:
    12个JavaScript MVC框架评估 简单
    chrome developer tool 调试技巧 简单
    转CSS3线性渐变 简单
    base64:URL背景图片与web页面性能优化 简单
    转linux下apache安装gzip压缩 简单
    转思考什么时候使用Canvas 和SVG 简单
    转周报的逻辑 简单
    浏览器三种刷新方式采取的不同缓存机制 简单
    poj 1308 Is It A Tree? (并查集)
    poj 2912 Rochambeau (并查集+枚举)
  • 原文地址:https://www.cnblogs.com/-Iris-/p/15340311.html
Copyright © 2020-2023  润新知