毒瘤题……少了一个判断检查了(40min)
看到题目 首先有个 (O(n^2k)) 的 DP 做法
就是
设 (sum(x)) = (sum_{i=1}^{x}a_i)
(f_{i,j})=(max){(f_{i-1,k} + sum(k) * (sum(j) - sum(k)))}
这样直接预处理出(sum_x=sum_{i=1}^{x}) 即前缀和 然后 (O(n^2k)) 就可以做出来了
考虑斜率优化
假设 (x) 优于 (y)
那么
(f_{i-1,x}+sum_x*(sum_j-sum_x) > f_{i-1.y}+sum_y*(sum_j-sum_y))
展开得到
(f_{i-1,x}+sum_x*sum_j-sum_x^2 > f_{i-1,y} + sum_y*sum_j - sum_y^2)
移项得到
((sum_x-sum_y)*sum_j > f_{i-1,y}-sum_y^2-f_{i-1,x}+sum_x^2)
化简得到
(sum_j > frac{f_{i-1,y}-sum_y^2-f_{i-1,x}+sum_x^2}{sum_x-sum_y})
设 (S(x)=sum_x^2-f_{i-1,x})
(sum_j>frac{S(x)-S(y)}{sum_x-sum_y})
愉快的开始斜率优化…
但是需要特判(sum_x==sum_y)的情况
// Isaunoya
#include<bits/stdc++.h>
using namespace std ;
#define int long long
#define fi first
#define se second
#define pb push_back
inline int read() {
register int x = 0 , f = 1 ;
register char c = getchar() ;
for( ; ! isdigit(c) ; c = getchar()) if(c == '-') f = -1 ;
for( ; isdigit(c) ; c = getchar()) x = (x << 1) + (x << 3) + (c & 15) ;
return x * f ;
}
template < typename T > inline bool cmax(T & x , T y) {
return x < y ? (x = y) , 1 : 0 ;
}
template < typename T > inline bool cmin(T & x , T y) {
return x > y ? (x = y) , 1 : 0 ;
}
inline int QP(int x , int y , int Mod){ int ans = 1 ;
for( ; y ; y >>= 1 , x = (x * x) % Mod)
if(y & 1) ans = (ans * x) % Mod ;
return ans ;
}
int n , k ;
const int N = 1e5 + 10 ;
const int M = 205 ;
int sum[N] , pre[M][N] ;
int q[N] ;
int f[N] , g[N] ;
inline int S(int x) {
return (sum[x] * sum[x] - g[x]) ;
}
inline double slope(int x , int y) {
if(sum[x] == sum[y]) return -1e18 ;
return 1.0 * (S(x) - S(y)) / (sum[x] - sum[y]) ;
}
signed main() {
n = read() ; k = read() ;
for(register int i = 1 ; i <= n ; i ++) sum[i] = sum[i - 1] + read() ;
memset(f , 0 , sizeof(f)) ; memset(g , 0 , sizeof(g)) ;
for(register int i = 1 ; i <= k ; i ++) {
int h = 1 , t = 0 ;
q[++ t] = 0 ;
for(register int j = 1 ; j <= n ; j ++) {
while(h < t && slope(q[h] , q[h + 1]) <= sum[j]) ++ h ;
f[j] = g[q[h]] + sum[q[h]] * (sum[j] - sum[q[h]]) ;
pre[i][j] = q[h] ;
while(h < t && slope(q[t - 1] , q[t]) >= slope(q[t] , j)) -- t ;
q[++ t] = j ;
}
memcpy(g , f , sizeof(f)) ;
} printf("%lld
" , f[n]) ;
for(register int i = n ; k ; -- k)
printf("%lld " , i = pre[k][i]) ;
return 0 ;
}