题意
有 (n) 个非负整数 (a_i) (( n le 4 * 10^7)), 将它们分为若干部分, 记为 (S_i), 要求 (S_{i+1} ge S_i),
设 (res=sum_{i=1}^{k} S_i^2).
求 (res) 的最小值
思路
64 pts
考场上写的做法.
设 (f[i][j]) 为第一段的起点为 (i), 终点为 (j) 时 (res) 的最小值.
朴素做法直接枚举 (i,j,k), 将 (f[i][j]) 从 (f[j][k]) 转移过来.
优化:
考虑中间的转移点 (j), 对于每个 (i<j), 合法的 (k) 的范围是递增的, 所以可以只枚举 (i,j) 然后指针 (k) 根据 (sum_j-sum_{i-1}) 不断忘前移, 并取最小值更新即可.
88 pts
设总共有 (k) 个块, 可以得出两个性质.
性质 1 : (k) 越大, 方案越优.
感性理解 : (a^2+b^2 le (a+b)^2).
进一步推断出,
性质 2 : (S_k) 越小, 方案越优.
感性理解 : (S_k) 越小, 为了使得方案合法, 就强迫前面的 (S) 也要尽量地小, 由于 (a) 是非负整数, 所以 (S) 的值越小, (k) 就越大, 根据性质 1, 方案就越优.
这样的话我们就可以维护一个单调队列, 对于每一位, 找到以它为结束点时, 最小的 (S_k).
在把新元素加入单调队列时还有一些细节, 详见代码.
100pts
上面的做法加个高精 (高精这种东西早就忘光了好吗...)
代码
64pts
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=5e3+7;
const ll inf=5e18;
int n,ty;
ll a[N],sum[N],f[N][N],ans=inf;
void read(){ }
ll p2(ll x){ return x*x; }
int main(){
//freopen("divd.in","r",stdin);
cin>>n>>ty;
if(ty) read();
else for(int i=1;i<=n;i++){ scanf("%lld",&a[i]); sum[i]=sum[i-1]+a[i]; }
memset(f,127,sizeof(f));
for(int i=1;i<=n;i++) f[i][n]=p2(sum[n]-sum[i-1]);
for(int j=n;j>=1;j--){
int k=n;
ll minx=inf;
for(int i=1;i<=j;i++){
while(p2(sum[j]-sum[i-1])<=p2(sum[k]-sum[j])){
minx=min(minx,f[j+1][k]);
k--;
}
f[i][j]=min(f[i][j],minx+p2(sum[j]-sum[i-1]));
}
}
for(int i=1;i<=n;i++) ans=min(ans,f[1][i]);
printf("%lld
",ans);
return 0;
}
88pts
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=4e7+7;
int n,ty,pre[N],que[N],t1,t2;
ll a[N],sum[N],lst[N],ans;
void read(){};
ll p2(ll x){ return x*x; }
int main(){
// freopen("divd.in","r",stdin);
// freopen("x.out","w",stdout);
cin>>n>>ty;
if(ty) read();
else for(int i=1;i<=n;i++){ scanf("%lld",&a[i]); sum[i]=sum[i-1]+a[i]; }
t1=t2=1;
que[1]=0;
for(int i=1;i<=n;i++){
// printf("%d: ",i); for(int j=t1;j<=t2;j++) printf("%d ",que[j]); putchar('
');
while(t1<t2&&lst[que[t1+1]]<=sum[i]-sum[que[t1+1]]) t1++;
pre[i]=que[t1];
lst[i]=sum[i]-sum[pre[i]];
while(t2>=t1&&lst[que[t2]]>=lst[i]+sum[i]-sum[que[t2]]) t2--;
que[++t2]=i;
}
int p=n;
while(p){
ans+=lst[p]*lst[p];
p=pre[p];
}
// for(int i=1;i<=n;i++) printf("pre[%d]: %d
",i,pre[i]);
printf("%lld
",ans);
return 0;
}