斜率优化学会以后好像也不是那么难嘛。。。
以BZOJ1911为例
->在洛谷上查看
设(s_i)为前(i)个元素的前缀和,(f_i)为dp数组。
(f_i=max{f_j+a(s_i-s_j)^2+b(s_i-s_j)+c})
(f_i=max{f_j+a(s_i^2-2s_is_j+s_j^2)+b(s_i-s_j)+c})
(f_i=max{f_j+as_i^2-2as_is_j+as_j^2+bs_i-bs_j+c})
(=max{f_j-2as_js_i+as_j^2-bs_j}+as_i^2+bs_i+c)
将(max)内部变为直线解析式。
令(k=-2as_j),(b=f_j+as_j^2-bs_j),自变量(x)。
原式(=kx+b)
由于(-5le ale-1),(s_i)单调递增
所以(k=-2as_i>0)且单调递增,即斜率递增。
我们的目标是找到当(x=s_i)时,(y)最大的一条直线。
如图是目前加入队列的直线,绿色直线为们查找的位置。
紫色部分的边缘便是每个位置的最大值,是一个凸壳。
由于斜率是递增的,所以蓝色直线必然在队尾。
发现我们找到的最大值在红线上,由于红线的斜率大于蓝线的斜率,且(s_i)递增(即接下来访问的位置都在绿线的右边),所以实际上蓝线代表的j已经不可能更新任何之后的状态了。
于是我们把它弹掉,然后用目前的队尾(就是之前最大值所在的线)更新答案,再把更新好的线加入队列。
但是还需要考虑一种情况。
就是一条直线上没有任何一段在凸壳上,即被原凸壳上的任意两条直线覆盖。
比如我们又加入了一根绿线。
我们发现红线被蓝线和绿线覆盖了,而这条刚刚被覆盖的直线总是加入新的直线之前的队首,并且当它被队首之后的那条直线和新加入的线覆盖,它也被任意两条直线覆盖。
想想为什么。这里不给出证明。
(其实就是懒得想)
那么如何判断两条线覆盖另一条线?
假设(l_1:k_1x+b_1),(l_2:k_2x+b_2),(l_3:k_3x+b_3),并且(l_1),(l_2)覆盖(l_3),(color{red}k_1>k_2),(color{red}k_1>k_3)。
我们可以找到(l_1),(l_2)的交点,过这个交点作(l_3)的平行线(l_4),显然(l_4)与当前凸壳相切与该点,如果(l_4)在(l_3)之上说明(l_3)不在凸壳上。
其实就是(l_1),(l_2)交点在(l_3)之上。
所以只需要比较这个交点的(y)坐标和(l_3)在(x)坐标相同时的(y)值即可。
具体如下:
先求(l_1),(l_2)交点。
(k_1x+b_1=k_2x+b_2)
解得(x=frac{b_2-b_1}{k_1-k_2}),此时(k_1x+b_1=k_2x+b_2=k_1frac{b_2-b_1}{k_1-k_2}+b_1)
(l_3)在(x)相等时的(y=k_3frac{b_2-b_1}{k_1-k_2}+b_3)
如果交点在(l_3)之上,
(k_1frac{b_2-b_1}{k_1-k_2}+b_1ge k_3frac{b_2-b_1}{k_1-k_2}+b_3)
(frac{b_2-b_1}{k_1-k_2}+frac{b_1}{k_1}ge frac{k_3}{k_1}frac{b_2-b_1}{k_1-k_2}+frac{b_3}{k_1})
(frac{1}{k_3}frac{b_2-b_1}{k_1-k_2}+frac{b_1}{k_1k_3}ge frac{1}{k_1}frac{b_2-b_1}{k_1-k_2}+frac{b_3}{k_1k_3})
(frac{1}{k_3}frac{b_2-b_1}{k_1-k_2}-frac{1}{k_1}frac{b_2-b_1}{k_1-k_2}ge frac{b_3}{k_1k_3}-frac{b_1}{k_1k_3})
(frac{k_1-k_3}{k_1k_3}frac{b_2-b_1}{k_1-k_2}ge frac{b_3-b_1}{k_1k_3})
(frac{b_2-b_1}{k_1-k_2}ge frac{b_3-b_1}{k_1-k_3})
因为(k_1>k_2),(k_1>k_3)
((b_2-b_1)(k_1-k_3)ge (b_3-b_1)(k_1-k_2))
所以满足上式时,(l_3)不在凸壳中。
于是就可以写代码了。
code:
#include<bits/stdc++.h>
using namespace std;
int n,v[1000010],s[1000010],q[1000010],l,r;
long long a,b,c,dp[1000010],lk[1000010],lb[1000010],tv;
void scan(int &x){
x=0;
char c=getchar();
while('0'>c||c>'9')c=getchar();
while('0'<=c&&c<='9')x=x*10+c-'0',c=getchar();
}
long long val(int i,int x){
return lk[i]*x+lb[i];
}
bool cov(int l1,int l2,int l3){//l1,l2 cover l3
return (lb[l2]-lb[l1])*(lk[l1]-lk[l3])>=(lb[l3]-lb[l1])*(lk[l1]-lk[l2]);
}
int main(){
scanf("%d%lld%lld%lld",&n,&a,&b,&c);
for(int i=1;i<=n;i++){
scan(v[i]);
s[i]=s[i-1]+v[i];
}
l=r=1;
q[1]=0;
dp[0]=0;
for(int i=1;i<=n;i++){
while(l<r&&val(q[l],s[i])<=val(q[l+1],s[i]))l++;
dp[i]=val(q[l],s[i])+a*s[i]*s[i]+b*s[i]+c;
lk[i]=-2*a*s[i];
lb[i]=dp[i]+a*s[i]*s[i]-b*s[i];
while(l<r&&cov(i,q[r-1],q[r]))r--;//不能写cov(q[r-1],i,q[r])
q[++r]=i;
}
printf("%lld",dp[n]);
return 0;
}