不会啊,看了好久的题解才看懂 TT
因为可以直接分成n段,所以就得到一个答案n,求解最小的答案,肯定是 <= n 的,
所以每一段中的不同数的个数都必须 <= sqrt(n),不然就不是最小的答案
那么
f[i]表示前i个数的最有解
g[i]表示从当前位置开始,有i个不同的数,最多能往前延伸到哪里
pre[i]表示上一个数为i的位置
cnt[i]表示g[i] + 1 ~ 当前位置 中的不同数的个数
所以 f[i] = min(f[i], f[g[j]] + j * j)
那么问题就是g数组的更新
如果 pre[x] > g[x],说明新的数在g[x] + 1 ~ 当前位置 中就包含了,不用更新,
否则g[x]就一直向后删除,直到有一种数全部删除,也就是到pre[x] <= g[x]
#include <cmath> #include <cstdio> #include <cstring> #include <iostream> #define N 40001 #define min(x, y) ((x) < (y) ? (x) : (y)) int n, m; int a[N], pre[N], cnt[N], g[N], f[N]; //f[i]表示前i个的最优解,g[i]表示数量为i最多能向左延伸到哪 inline int read() { int x = 0, f = 1; char ch = getchar(); for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = -1; for(; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + ch - '0'; return x * f; } int main() { int i, j, k, x; n = read(); m = read(); m = sqrt(n); memset(f, 127, sizeof(f)); for(i = 1; i <= n; i++) a[i] = read(); f[0] = 0; for(i = 1; i <= n; i++) { for(j = 1; j <= m; j++) if(pre[a[i]] <= g[j]) cnt[j]++; pre[a[i]] = i; for(j = 1; j <= m; j++) if(cnt[j] > j) { k = g[j] + 1; while(pre[a[k]] > k) k++; g[j] = k; cnt[j]--; } for(j = 1; j <= m; j++) f[i] = min(f[i], f[g[j]] + j * j); } printf("%d ", f[n]); return 0; }