斜率优化DP总结
一般来讲,斜率优化DP的状态转移方程式为:
其中,(val(i,j))包含(i)和(j)的乘积项。
对于此类问题,我们一般会通过将该问题转化为维护凸包,通过规划进一步思考优化。
事实上,我们更倾向于将方程变形,通过观察每一位优化。
例题:任务安排
在总结斜率优化DP之前,还是要说一句,状态转移方程是才最关键的,状态转移方程想不出来,就完蛋。
这道题首先我们考虑定义(dp[i])代表前(i)台机器划分最小代价。
由于我们不知道前面最优解划分了多少批任务,因此,我们采用费用提前计算的思想进行解题。
那么,我们有:
可以看出,在方程最后,我们直接将(j)之后的所有的(C[k])在此处进行累加。因为如果后面的状态转移到了该状态,那么后面的费用一定包含该次断点的(s)乘以后面的(C[k])。不如我们提前计算好了,方便转移,避免状态冗长。
接下来,我们进行斜率优化的变形:
-
不妨设最优解在(j)处取得,那么有
[dp[i]=dp[j]+sumT[i]*(sumC[i]-sumC[j])+s*(sumC[n]-sumC[j]) ] -
变形得:
[dp[j]=(sumT[i]-s)*sumC[j]+dp[i]-sumT[i]*sumC[i]-s*sumC[n] ]
注意到,我们将(dp[j])看作(y),将(sumC[j])看成(x),那么我们要做的是最小化截距(dp[i]-sumT[i]*sumC[i]-s*sumC[n])。
事实上,上面那个变形并不是唯一的。一般而言,将乘积项转化为(kx),把没有任何系数的项看成(y),其它项看成截距。
按照刚刚的表达式来讲,我们可以维护一个队列,队列中包括每个“候选点”,用当前斜率(sumT[i]-s)进行查找,找到截距最小。接着,在更新完刚刚的值之后,我们将新产生的值加入队列中,并维护凸包。
至于为什么要维护凸包,我们给出以下证明:
观察这张,中间那个点永远不可能是最优点,因此,维护凸包,将这样的点干掉。
斜率优化细节还是蛮多的,因此一定要注意几点:
- 更新队列时要保证队列中至少有两个元素;
- 当比较斜率大小时,一般使用乘法比较,此时不等号是否需要收积数符号的影响。
另外这道题的特殊性在于:每一次的斜率单调递增,因而采取删除队首元素即可。
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#define pii pair <int, int>
#define mp(x, y) make_pair(x, y)
#define FOR(i, a, b) for(register int i = a; i <= b; ++ i)
#define ROF(i, a, b) for(register int i = a; i <= b; -- i)
using namespace std;
const int N = 300000 + 5;
typedef long long LL;
int n, s, q[N];
LL T[N], C[N], sumT[N], sumC[N], dp[N];
int main()
{
scanf("%d %d", &n, &s);
memset(sumT, 0, sizeof(sumT));
memset(sumC, 0, sizeof(sumC));
FOR(i, 1, n)
{
scanf("%lld %lld", &T[i], &C[i]);
sumT[i] = sumT[i - 1] + T[i], sumC[i] = sumC[i - 1] + C[i];
}
memset(dp, 0x3f, sizeof(dp));
dp[0] = 0;
int head = 1, tail = 1;
q[tail] = 0;
FOR(i, 1, n)
{
while(head < tail && (dp[q[head + 1]] - dp[q[head]]) < (sumT[i] + s) * (sumC[q[head + 1]] - sumC[q[head]])) ++ head;
dp[i] = dp[q[head]] + sumT[i] * (sumC[i] - sumC[q[head]]) + s * (sumC[n] - sumC[q[head]]);
while(head < tail && (dp[i] - dp[q[tail]]) * (sumC[q[tail]] - sumC[q[tail - 1]]) < (dp[q[tail]] - dp[q[tail - 1]]) * (sumC[i] - sumC[q[tail]])) -- tail;
q[++ tail] = i;
}
printf("%lld
", dp[n]);
return 0;
}
如果这道题斜率不单调递增,那么我们就二分最优值。
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#define CLR(a, x) memset(a, x, sizeof(a))
#define FOR(i, x, y) for(register int i = x; i <= y; ++ i)
#define ROF(i, x, y) for(register int i = x; i <= y; -- i)
#define pii pair <int, int>
#define mp(x, y) make_pair(x, y)
using namespace std;
const int N = 3e5 + 5;
typedef long long LL;
int n, s, q[N];
LL T[N], C[N], sumT[N], sumC[N], dp[N];
int main()
{
scanf("%d %d", &n, &s);
CLR(sumT, 0);
CLR(sumC, 0);
FOR(i, 1, n)
{
scanf("%lld %lld", &T[i], &C[i]);
sumT[i] = sumT[i - 1] + T[i], sumC[i] = sumC[i - 1] + C[i];
}
CLR(dp, 0x3f);
int head = 1, tail = 1, L, R, mid;
dp[0] = 0;
q[head] = 0;
FOR(i, 1, n)
{
L = head, R = tail;
while(L < R)
{
mid = L + ((R - L) >> 1);
if(dp[q[mid + 1]] - dp[q[mid]]< (sumT[i] + s) * (sumC[q[mid + 1]] - sumC[q[mid]])) L = mid + 1;
else R = mid;
}
dp[i] = dp[q[L]] + sumT[i] * (sumC[i] - sumC[q[L]]) + s * (sumC[n] - sumC[q[L]]);
while(head < tail && (dp[i] - dp[q[tail]]) * (sumC[q[tail]] - sumC[q[tail - 1]]) < (dp[q[tail]] - dp[q[tail - 1]]) * (sumC[i] - sumC[q[tail]])) -- tail;
q[++ tail] = i;
}
printf("%lld
", dp[n]);
return 0;
}
例题:运输小猫
考虑用每个小猫时间减去到(1)号点的距离,sort,然后这道题跟上一道题类似,并且按照上一道题的做法,这道题斜率优化部分稍微容易。
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<deque>
#define FOR(i, x, y) for(register int i = x; i <= y; ++ i)
#define ROF(i, x, y) for(register int i = x; i >= y; -- i)
#define CLR(a, b) memset(a, b, sizeof(a))
#define pii pair <int, int>
#define mp(x, y) make_pair(x, y)
using namespace std;
const int N = 1e5 + 5, P = 100 + 10;
typedef long long LL;
int n, m, p, D[N] = {}, q[N];
LL A[N], s[N], dp[P][N];
int main()
{
CLR(s, 0);
scanf("%d %d %d", &n, &m, &p);
FOR(i, 2, n)
{
scanf("%d", &D[i]);
D[i] += D[i - 1];
}
FOR(i, 1, m)
{
LL H, T;
scanf("%d %lld", &H, &T);
A[i] = T - D[H];
}
int head = 1, tail = 1;
sort(A + 1, A + m + 1);
FOR(i, 1, m) s[i] = s[i - 1] + A[i];
CLR(dp, 0x3f);
FOR(i, 0, p - 1) dp[i][0] = 0;
FOR(i, 1, p)
{
head = tail = 1;
q[tail] = 0;
FOR(j, 1, m)
{
while(head < tail && (dp[i - 1][q[head + 1]] + s[q[head + 1]] - dp[i - 1][q[head]] - s[q[head]]) <= A[j] * (q[head + 1] - q[head])) ++ head;
dp[i][j] = dp[i - 1][q[head]] + A[j] * (j - q[head]) - (s[j] - s[q[head]]);
while(head < tail && (dp[i - 1][q[tail]] + s[q[tail]] - dp[i - 1][q[tail - 1]] - s[q[tail - 1]]) * (j - q[tail]) > (dp[i - 1][j] + s[j] - dp[i - 1][q[tail]] - s[q[tail]]) * (q[tail] - q[tail - 1])) -- tail;
q[++ tail] = j;
}
}
printf("%lld
", dp[p][m]);
return 0;
}
例题:K匿名序列
这道题也一样,就是维护的时候一定要小心即可。
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#define FOR(i, x, y) for(register int i = x; i <= y; ++ i)
#define ROF(i, x, y) for(register int i = x; i >= y; -- i)
#define CLR(x, y) memset(x, y, sizeof(x))
using namespace std;
const int N = 500000 + 7;
typedef long long LL;
int n, k, a[N], q[N];
LL b[N], s[N], dp[N];
LL d_x(int x, int y)
{
return a[x + 1] - a[y + 1];
}
LL d_y(int x, int y)
{
return (dp[x] - dp[y]) - (s[x] - s[y]) + (b[x] - b[y]);
}
void prework()
{
CLR(a, 0), CLR(b, 0);
CLR(s, 0), CLR(q, 0);
a[n + 1] = 1 << 30;
return;
}
int main()
{
int T;
scanf("%d", &T);
while(T --)
{
scanf("%d %d", &n, &k);
prework();
FOR(i, 1, n)
{
scanf("%d", &a[i]);
s[i] = s[i - 1] + a[i];
}
FOR(i, 1, n) b[i] = 1ll * a[i + 1] * i;
int head = 1, tail = 1;
CLR(dp, 0x3f);
q[tail] = 0;
dp[0] = 0;
FOR(i, k, n)
{
while(head < tail && d_y(q[head + 1], q[head]) <= i * d_x(q[head + 1], q[head])) ++ head;
dp[i] = dp[q[head]] + s[i] - s[q[head]] - a[q[head] + 1] * (i - q[head]);
if(i + 1 >= k * 2)
{
while(head < tail && d_y(q[tail], q[tail - 1]) * d_x(i - k + 1, q[tail]) >= d_y(i - k + 1, q[tail]) * d_x(q[tail], q[tail - 1])) -- tail;
q[++ tail] = i - k + 1;
}
}
printf("%lld
", dp[n]);
}
return 0;
}
例题:玩具装箱
这道题也一样,首先(L)加上(1)。
然后就跟上一题没什么区别了。
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#define CLR(x, y) memset(x, y, sizeof(x))
#define FOR(i, x, y) for(register int i = x; i <= y; ++ i)
#define ROF(i, x, y) for(register int i = x; i >= y; -- i)
using namespace std;
typedef long long LL;
const int N = 50000 + 5;
int n, head, tail, q[N] = {};
LL L, C[N], dp[N], sumC[N], sumK[N];
LL d_x(int x, int y)
{
return sumK[x] - sumK[y];
}
LL d_y(int x, int y)
{
return (dp[x] - dp[y]) + (sumK[x] * sumK[x] - sumK[y] * sumK[y]);
}
int main()
{
scanf("%d %d", &n, &L);
++ L;
CLR(sumC, 0), CLR(sumK, 0);
FOR(i, 1, n)
{
scanf("%d", &C[i]);
sumC[i] = sumC[i - 1] + C[i];
sumK[i] = sumC[i] + i;
}
CLR(dp, 0x3f);
dp[0] = 0;
head = tail = 1;
q[tail] = 0;
FOR(i, 1, n)
{
while(head < tail && d_y(q[head + 1], q[head]) < 2 * (sumK[i] - L) * d_x(q[head + 1], q[head])) ++ head;
dp[i] = dp[q[head]] + (sumK[i] - sumK[q[head]] - L) * (sumK[i] - sumK[q[head]] - L);
while(head < tail && d_y(q[tail], q[tail - 1]) * d_x(i, q[tail]) > d_y(i, q[tail]) * d_x(q[tail], q[tail - 1])) -- tail;
q[++ tail] = i;
}
printf("%lld
", dp[n]);
return 0;
}
习题:特别行动队
这道题看起来特别难,实际上挺基本的。
移项得:
这次斜率(2*a*s[i])单调递减了,如出一辙。
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#define CLR(x,y) memset(x,y,sizeof(x))
#define FOR(i,x,y) for(register int i=x; i<=y; ++i)
#define ROF(i,x,y) for(register int i=x; i>=y; --i)
using namespace std;
const int N = 1000000 + 5;
typedef long long LL;
int n, q[N];
LL a, b, c, x[N], s[N], dp[N];
LL d_x(int p, int q)
{
return s[p] - s[q];
}
LL d_y(int p, int q)
{
return (dp[p] - dp[q]) + (a * s[p] * s[p] - a * s[q] * s[q]) + b * (s[q] - s[p]);
}
int main()
{
scanf("%d %lld %lld %lld", &n, &a, &b, &c);
CLR(q, 0), CLR(s, 0);
FOR(i, 1, n)
{
scanf("%lld", &x[i]);
s[i] = s[i - 1] + x[i];
}
CLR(dp, 0xcf);
dp[0] = 0;
int head, tail, j;
head = tail = 1;
q[tail] = 0;
FOR(i, 1, n)
{
while(head < tail && d_y(q[head + 1], q[head]) > 2 * a * s[i] * d_x(q[head + 1], q[head])) ++ head;
j = q[head];
dp[i] = dp[j] + a * (s[i] - s[j]) * (s[i] - s[j]) + b * (s[i] - s[j]) + c;
while(head < tail && d_y(q[tail], q[tail - 1]) * d_x(i, q[tail]) <= d_y(i, q[tail]) * d_x(q[tail], q[tail - 1])) -- tail;
q[++ tail] = i;
}
printf("%lld
", dp[n]);
return 0;
}
总结:
斜率优化核心有几点:首先,关键的一步是想出状态转移方程,方程没有一切都不行。
其次,斜率优化得写对。
如果横坐标不单调递增,这下子比较麻烦,得创造平衡树动态维护凸包,当然思路还是不难(转移最难了)。