题目大意:
给你一个长度为n的序列a,你可以将其分为若干段,最终的答案为每一段不同数个数的平方和。
思路:
不难想到一个O(n^2)的DP:
f[i]=min{f[j]+cnt(j,i)^2}
考虑一些优化。
首先不难发现,答案最坏不会超过n。(一个数一段)
要让答案更优,一段内不同数的个数不会超过sqrt(n)。(不然平方之后就超过n了)。
我们把到i有j个不同数的最后位置记作pos[j]。
考虑如何维护这个pos[j]。
我们可以先将每个数字出现的上一个位置记作last[i],一段中出现的不同数字个数记作cnt[i]。
首先,如果last[a[i]]<=pos[j],则cnt[j]++。
当cnt[j]>j时,把左端点往右缩,如果这时last[a[pos[j]]]==pos[j],cnt[j]--。
1 #include<cmath> 2 #include<cstdio> 3 #include<cctype> 4 #include<algorithm> 5 inline int getint() { 6 register char ch; 7 while(!isdigit(ch=getchar())); 8 register int x=ch^'0'; 9 while(isdigit(ch=getchar())) x=(((x<<2)+x)<<1)+(ch^'0'); 10 return x; 11 } 12 const int inf=0x7fffffff; 13 const int N=40001,M=40001; 14 int a[N],f[N],pos[M],last[M],cnt[M]; 15 int main() { 16 int n=getint(),m=getint(); 17 for(register int i=1;i<=n;i++) { 18 const int x=getint(); 19 if(x!=a[a[0]]) a[++a[0]]=x; 20 } 21 n=a[0],m=sqrt(a[0]); 22 for(register int i=1;i<=n;i++) { 23 f[i]=inf; 24 for(register int j=1;j<=m;j++) { 25 if(last[a[i]]<=pos[j]) cnt[j]++; 26 } 27 last[a[i]]=i; 28 for(register int j=1;j<=m;j++) { 29 while(cnt[j]>j) { 30 pos[j]++; 31 if(last[a[pos[j]]]==pos[j]) cnt[j]--; 32 } 33 } 34 for(register int j=1;j<=m;j++) { 35 f[i]=std::min(f[i],f[pos[j]]+j*j); 36 } 37 } 38 printf("%d ",f[n]); 39 return 0; 40 }