Description
Input
输入第一行包含两个整数n,k(k+1≤n)。
Output
输出第一行包含一个整数,为小H可以得到的最大分数。
Sample Input
4 1 3 4 0 2 3
Sample Output
HINT
【样例说明】
在样例中,小H可以通过如下3轮操作得到108分:
1.-开始小H有一个序列(4,1,3,4,0,2,3)。小H选择在第1个数之后的位置将序列分成两部分,并得到4×(1+3+4+0+2+3)=52分。
2.这一轮开始时小H有两个序列:(4),(1,3,4,0,2,3)。小H选择在第3个数字之后的位置将第二个序列分成两部分,并得到(1+3)×(4+0+2+ 3)=36分。
3.这一轮开始时小H有三个序列:(4),(1,3),(4,0,2,3)。小H选择在第5个数字之后的位置将第三个序列分成两部分,并得到(4+0)×(2+3)= 20分。
经过上述三轮操作,小H将会得到四个子序列:(4),(1,3),(4,0),(2,3)并总共得到52+36+20=108分。
【数据规模与评分】
数据满足2≤n≤100000,1≤k≤min(n -1,200)。
我们发现如果我们把序列分成4个部分 x1、x2、x3、x4
那么我们无论按照什么顺序分割,最后的答案都是 x1 * x2 + x1 * x3 + x1 * x4 + x2 * x3 + x2 * x4 + x3 * x4
你会发现最终答案与分割顺序无关,只与分割位置有关。
关于要分割 k 次的 k 怎么处理,当时我问老师:这个需要开二维dp数组然后维护k个凸壳吗? 老师说:k个也可以。。。
其实直接跑k次dp就可以了,因为每次要利用dp[j][kk-1],我们用g来代替kk-1状态的dp数组,f来表示kk状态的dp数组,用g更新f,然后下一次dp时(kk+1状态)直接把f复制到g里。
接下来就是推式子斜率优化啦
每次从左到右多加一个分割点,答案相当于 从 x1 * x2 + x1 * x3 + x2 * x3 变到 x1 * x2 + x1 * x3 + x1 * x4 + x2 * x3 + x2 * x4 + x3 * x4
多了 x4 * (x1 + x2 + x3)
用c表示前缀和
则没加一个分割点,多了 (c[i] - c[j]) * c[j]
f[i] = g[j] + (c[i] - c[j]) * c[j]
=> f[i] = g[j] + c[i] * c[j] - c[j]2
=> f[i] - c[i] * c[j] = g[j] - c[j]2
当然当时我比较傻,当时把 j 作为上一个这一段区间的左端点定义的,而不是按上一个分割点定义的,所以代码里一堆 j-1就是 式子里的 j 啦
#include<algorithm> #include<iostream> #include<cstring> #include<cstdlib> #include<cstdio> #include<cmath> using namespace std; const int maxn=1e5+10,maxk=210; long long n,k,c[maxn],f[maxn],g[maxn],zz[maxn]; long long aa,fl;char cc; long long read(){ aa=0;cc=getchar();fl=1; while((cc<'0'||cc>'9')&&cc!='-') cc=getchar(); if(cc=='-') fl=-1,cc=getchar(); while(cc>='0'&&cc<='9') aa=aa*10+cc-'0',cc=getchar(); return aa*fl; } bool ok(int x,int y,int z) { return (-c[x-1]*c[x-1]+c[y-1]*c[y-1]+g[x-1]-g[y-1])*(c[z-1]-c[y-1])>(-c[y-1]*c[y-1]+c[z-1]*c[z-1]+g[y-1]-g[z-1])*(c[y-1]-c[x-1]); } void dp(long long x){ int l=1,r=0; for(int i=x;i<=n;++i) { while(l<r&&ok(zz[r-1],zz[r],i)) r--; zz[++r]=i; // printf("%d ",r); while(l<r&&c[i]*c[zz[l]-1]-c[zz[l]-1]*c[zz[l]-1]+g[zz[l]-1]<c[i]*c[zz[l+1]-1]-c[zz[l+1]-1]*c[zz[l+1]-1]+g[zz[l+1]-1]) l++; // printf("%d ",l); f[i]=c[i]*c[zz[l]-1]-c[zz[l]-1]*c[zz[l]-1]+g[zz[l]-1]; } memcpy(g,f,sizeof(f)); } int main() { n=read();k=read(); for(int i=1;i<=n;++i) { c[i]=read(); if(c[i])c[i]+=c[i-1]; else i--,n--; } for(int i=1;i<=k;++i) dp(i); printf("%lld",g[n]); return 0; }