题目大意
将一个长度为 $n$ 的序列分为 $k$ 段
使得总价值最大,一段区间的价值表示为区间内不同数字的个数。
$n leq 35000, k leq 50$。
Solution
线段树优化区间 $dp$ 好题。
我们先来想 $dp$ 部分。
考虑子状态 $dp[i][j]$ 代表前 $i$ 个值分成 $j$ 段所得大的最大价值。
则有 $dp[i][j]=max{dp[l][j-1]+val(l+1,i)},0 leq l < i$。$val(x,y)$ 代表 $[x,y]$ 之间不同的数的个数。
而这样写的时间复杂度是 $O(n^2 imes k)$ 的。
我们考虑如何去优化它。
我们发现段数 $j$ 的值只与 $j-1$ 有关,所以我们不妨先枚举段数。
再来看我们的转移方程 “$dp[i][j]=max{dp[l][j-1]+val(l+1,i)},0 leq l < i$。$val(x,y)$”,其实就是取一个区间最大值。
不过这里有一个 $val(l+1,i)$ 非常的讨厌。
我们来研究它的值。设当前值是 $a[i]$。
因为现在枚举的位置已经到 $i$,所以 $val(l+1,i-1)$ 的值是我们已知的。
如果区间 $[l+1,i-1]$ 内已经有 $a[i]$ 了,那么就不会产生贡献。我们记录 $a[i]$ 这个值上一次出现的位置记为 $pre[i]$。
那么只有属于 $[pre[i],i)$ 的 $l$ 的 $val(l+1,i)$ 相比 $val(l+1,i-1)$ 才会加 $1$。
接下来看如何用线段树维护。
其实很简单。
线段树上的节点 $[x, y]$(这里用区间来代表节点)存的就是所有 $l$ 属于 $[x,y]$ 的最大的 $dp[l][j-1]+val(l+1,i)$。
考虑现在枚举到 $i$,就把所有 $[pre[i],i)$ 的值加 $1$。转移就是 $[0,i)$ 中的最大值。
区间修改+区间查询搞定。
上代码。
#include <iostream> #include <cstdio> using namespace std; struct Segment{ int val, tag; }st[35010 << 2]; int n, k; int a[35010]; int pre[35010], head[35010]; int dp[35010]; void push_down(int p) { if (st[p].tag) { st[p << 1].val += st[p].tag; st[p << 1].tag += st[p].tag; st[p << 1 | 1].val += st[p].tag; st[p << 1 | 1].tag += st[p].tag; st[p].tag = 0; } } void build(int p, int l, int r) { st[p].tag = 0; if (l == r) { st[p].val = dp[l]; return; } int mid = (l + r) >> 1; build(p << 1, l, mid); build(p << 1 | 1, mid + 1, r); st[p].val = max(st[p << 1].val, st[p << 1 | 1].val); } void change(int p, int l, int r, int L, int R) { if (L <= l && r <= R) { st[p].tag++; st[p].val++; return; } push_down(p); int mid = (l + r) >> 1; if (L <= mid) change(p << 1, l, mid, L, R); if (mid < R) change(p << 1 | 1, mid + 1, r, L, R); st[p].val = max(st[p << 1].val, st[p << 1 | 1].val); } int ask(int p, int l, int r, int L, int R) { if (L <= l && r <= R) { return st[p].val; } push_down(p); int mid = (l + r) >> 1, ret = 0; if (L <= mid) ret = max(ret, ask(p << 1, l, mid, L, R)); if (mid < R) ret = max(ret, ask(p << 1 | 1, mid + 1, r, L, R)); return ret; } int main() { scanf("%d%d", &n, &k); for (int i = 1; i <= n; i++) scanf("%d", &a[i]); for (int i = 1; i <= n; i++) { pre[i] = head[a[i]]; head[a[i]] = i; } for (int j = 1; j <= k; j++) { build(1, 0, n); for (int i = j; i <= n; i++) { change(1, 0, n, pre[i], i - 1); dp[i] = ask(1, 0, n, 0, i - 1); } } cout << dp[n]; return 0; }