拿到这道题,考虑先简化问题,我们发现可以把 (0/1) 串缩成一个数列 (a) ,(a_i) 表示第 (i) 个 (1) 与第 (i-1) 个 (1) 有 (a_i) 个 (0),记序列 (a) 的长度为 (m)。
那么移动盾牌这一操作就变为将 (a_x) 加 (1),(a_{x+1}) 减 (1) 或是将 (a_x) 减 (1) ,(a_{x+1}) 加 (1)。而 (ans=sumlimits_{i=1}^msumlimits_{j=i+1}^m a_ia_j=frac{1}{2}left(sumlimits_{i=1}^msumlimits_{j=1}^ma_ia_j-sumlimits_{i=1}^ma_i^2 ight))。于是我们只需要求出进行了 (i) 次操作后,(min left{sumlimits_{i=1}^ma_i^2 ight}) 的值即可。
这个值显然无法贪心求出,于是考虑 ( ext{DP})。设 (f_{i,j,k}) 为序列中前 (i) 个数,进行了 (k) 次操作后,当前 (sumlimits_{x=1}^ia_x=j),所能够得到的 (minleft{sumlimits_{x=1}^i a_x^2 ight})。为了方便,我们预处理出 (s_i=sumlimits_{j=1}^ia_j)。
于是有转移式:
此处转移式的 (|l-s_{i+1}|+k),其实是一个结论:对于两个长度为 (n) 的序列 (a,b),可以对序列 (a) 的相邻两数 (+1,-1),经过操作由 (a) 得到序列 (b) 的最小操作数为 (sum_{i=1}^n left|sumlimits_{j=1}^i a_j-sumlimits_{j=1}^i b_j ight|)。前提是序列 (a) 经过操作能够得到序列 (b) ,也就是 (sum_ia_i=sum_i b_i)。感性理解这个结论是非常容易的,对当前数 (a_i) 进行操作将它变为 (b_i),会对以后的数产生影响:如果对 (a_i) 加,就是对 (a_{i+1}) 减;如果对 (a_i) 减,就是对 (a_{i+1}) 加。而转移式中的 (|l-s_{i+1}|+k) 其实就是 (left|sumlimits_{j=1}^i a_j-sumlimits_{j=1}^i b_j ight|) 的体现。
此处附上一段更为严谨的证明:
可以看作每个位置有一个东西,初始时 (a_i<b_i) 这个位置上是负值,(a_i>b_i) ,这个位置上是正值,每次可以把一个正的移动一下,那就是按顺序匹配一下,考虑每条边的贡献就是前缀和的绝对值之和 ——myh 神仙
这样直接枚举每个量刷表,时间复杂度是 (O(n^5)) 的,然而由于 (j) 这一维的特殊性,有效的 (j) 很少,所以实际时间复杂度远低于 (O(n^5))。当然,你也可以选择对这个方程进行斜率优化,可以做到稳定的 (O(n^4 log n))。
Show the Code
#include<cstdio>
#include<cstring>
const int inf=0x3f3f3f3f;
int a[85],b[85],s[85],f[85][85][3205];
inline int read() {
register int x=0,f=1;register char s=getchar();
while(s>'9'||s<'0') {if(s=='-') f=-1;s=getchar();}
while(s>='0'&&s<='9') {x=x*10+s-'0';s=getchar();}
return x*f;
}
inline int min(const int &x,const int &y) {return x<y? x:y;}
inline int max(const int &x,const int &y) {return x>y? x:y;}
inline int abs(const int &x) {return x<0? -x:x;}
int main() {
int n=read(),m=1,sum=0;
for(register int i=1;i<=n;++i) a[i]=read();
for(register int i=1;i<=n;++i) {if(a[i]==0) ++b[m]; else ++m;}
for(register int i=1;i<=m;++i) {s[i]=s[i-1]+b[i];}
memset(f,0x3f,sizeof(f)); f[0][0][0]=0;
for(register int i=0;i<m;++i) {
for(register int j=0;j<=s[m];++j) {
for(register int k=0;k<=n*(n-1)/2;++k) {
if(f[i][j][k]==inf) continue;
for(register int x=j;x<=s[m];++x) {
if(k+abs(x-s[i+1])>n*(n-1)/2) continue;
int &tmp=f[i+1][x][k+abs(x-s[i+1])];
tmp=min(tmp,f[i][j][k]+(x-j)*(x-j));
}
}
}
}
int mn=inf;
for(register int i=0;i<=n*(n-1)/2;++i) {mn=min(mn,f[m][s[m]][i]);printf("%d ",(s[m]*s[m]-mn)>>1);}
return 0;
}