功能:一个万能的斜率优化模板 可以解决横坐标不单调 查询坐标不单调的问题
斜率优化问题解决方法:
斜率优化问题是当dp式类似$dp_i = dp_j + a_i * b_j$形式时无法左右分离i与j时的一种优化复杂度的方法。
我们通过变换得到类似$y = k * x + b$形式,其中$x$和$y$是只关于$i$的项,$k$和$b$是只关于$j$的项。同时$dp_i$在$y$中,我们希望选择合适的$k$与$b$使得$y$最大或最小,从而使dp[i]最大或最小。
于是我们需要维护一个直线集合$(k,b)$,当需要$dp_i$最大时直线集合形成下凸的凸包,反之则形成上凸的凸包,每次查询$query(x)$即可得出结果,加入直线则是$add(k,b)$,复杂度$O(nlogn)$。
模板:
namespace { struct Line { mutable ll k, m, p; bool f; // 存在斜率吗 Line() {} Line(ll _k, ll _m, ll _p, bool _f) : k(_k), m(_m), p(_p), f(_f) {} bool friend operator < (const Line &a, const Line &b) { return (a.f && b.f) ? a.k < b.k : a.p < b.p; } }; struct LineContainer : multiset<Line> { // LineContainer() {} const ll inf = LLONG_MAX; ll div(ll a, ll b) { //求交点 return a / b - (a ^ b < 0 && a % b); } ld div(ld a, ld b) { return a / b; } bool Intersect(iterator x, iterator y) { if(y == end()) { x -> p = inf; return false; } if(x -> k == y -> k) x -> p = x -> m > y -> m ? inf : -inf; else x -> p = div(y -> m - x -> m, x -> k - y -> k); return x -> p >= y -> p; } void add(ll k, ll m) { multiset<Line> :: iterator z = insert(Line(k, m, 0, 1)), y = z++, x = y; while(Intersect(y, z)) z = erase(z); if(x != begin() && Intersect(--x, y)) Intersect(x, y = erase(y)); while((y = x) != begin() && (--x) -> p >= y -> p) Intersect(x, erase(y)); } ll query(ll x) { // assert(!empty()); multiset<Line> :: iterator L = lower_bound(Line(0, 0, x, 0)); return L -> k * x + L -> m; } }; }
例题:
Codeforces 631E
大意:使得$sum_{i=1}^{n}{a_{i}*i}$最大,可以将一个$a_i$插入到任何一个位置
题解:需要两次$dp$,没有差别,仅说明第一遍$dp$。
枚举位置$i$,考虑移到位置$j$前面且 $j leq i$。$ans = max(tot + a_{i}*j-a_{i}*i+sum_{i-1}-sum_{j-1})$。化成$y = k * x + b$形式,得出$ans-tot-sum_{i-1}+a_{i}*i=a_{i}*j-sum_{j-1}$
所以
$k = j$
$b = -sum_{j-1}$
$x = a_i$
$y = ans - tot - sum_{i-1} + a_{i} * i$
维护下凸壳,具体见代码。答案即是$query(a_{i})+tot+sum_{i-1}-a_{i}*i$。
#include <cstdio> #include <cstring> #include <algorithm> #include <set> #include <vector> using namespace std; typedef long long ll; typedef long double ld; namespace { struct Line { mutable ll k, m, p; bool f; // 存在斜率吗 Line() {} Line(ll _k, ll _m, ll _p, bool _f) : k(_k), m(_m), p(_p), f(_f) {} bool friend operator < (const Line &a, const Line &b) { return (a.f && b.f) ? a.k < b.k : a.p < b.p; } }; struct LineContainer : multiset<Line> { // LineContainer() {} const ll inf = LLONG_MAX; ll div(ll a, ll b) { //求交点 return a / b - (a ^ b < 0 && a % b); } ld div(ld a, ld b) { return a / b; } bool Intersect(iterator x, iterator y) { if(y == end()) { x -> p = inf; return false; } if(x -> k == y -> k) x -> p = x -> m > y -> m ? inf : -inf; else x -> p = div(y -> m - x -> m, x -> k - y -> k); return x -> p >= y -> p; } void add(ll k, ll m) { multiset<Line> :: iterator z = insert(Line(k, m, 0, 1)), y = z++, x = y; while(Intersect(y, z)) z = erase(z); if(x != begin() && Intersect(--x, y)) Intersect(x, y = erase(y)); while((y = x) != begin() && (--x) -> p >= y -> p) Intersect(x, erase(y)); } ll query(ll x) { // assert(!empty()); multiset<Line> :: iterator L = lower_bound(Line(0, 0, x, 0)); return L -> k * x + L -> m; } }; } const int maxn = 5e5 + 5; int n; ll tot, ans; ll a[maxn], sum[maxn]; LineContainer H; int main() { scanf("%d", &n); for(int i = 1; i <= n; ++i) { scanf("%lld", &a[i]); tot += a[i] * i; sum[i] = sum[i - 1] + a[i]; } ans = tot; for(int i = 1; i <= n; ++i) { if(!H.empty()) ans = max(ans, H.query(a[i]) + sum[i - 1] + tot - a[i] * i); H.add(1.0 * i, 1.0 * -sum[i - 1]); } H.clear(); for(int i = n; i; --i) { if(!H.empty()) ans = max(ans, H.query(a[i]) + sum[i] + tot - a[i] * i); H.add(1.0 * i, 1.0 * -sum[i]); } printf("%lld ", ans); return 0; } /* 2 3 4 5 6 5 2 3 4 6 j = 2 i = 5 前后做两遍 前: 向前移i移到j前面 ans = max(tot + a[i] * j - a[i] * i + (sum[i - 1] - sum[j - 1])) = max(tot + a[i] * j - sum[i - 1] - a[i] * i + sum[j - 1]) j <= i y = k * x + b y 最大 ans - tot - sum[i - 1] + a[i] * i = a[i] * j - sum[j - 1] x = a[i] y = ans - tot - sum[i - 1] + a[i] * i k = j b = -sum[j - 1] 维护下凸包 后: 2 3 4 5 6 3 4 5 2 6 ans = max(tot + a[i] * j - a[i] * i + (sum[i] - sum[j])) */