D - The Bakery
这个题目好难啊,我理解了好久,都没有怎么理解好,
这种线段树优化dp,感觉还是很难的。
直接说思路吧,说不清楚就看代码吧。
这个题目转移方程还是很好写的,
dp[i][j]表示前面 i 个蛋糕 分成了 j 个数字的最大价值。
dp[i][j]=max(dp[k][j-1]+val[k+1~i])
显而易见的是,这个肯定不可以直接暴力求,所以就要用到线段树优化。
线段树怎么优化呢,
先看这个问题,给你一个点 x ,问你以这个点为右端点的所有区间有多少种数字,
这个很简单是不是,那继续问你 从x 到 x+1 这个点怎么转移?
是不是找到 last[a[x+1]] 上一次出现a[x+1] 这个数字的位置,从这个位置+1到 x+1 这个位置,所有的区间都+1
这个是不是就是线段树的更新,那么线段树的每一个位置是不是随着我们对 i 的枚举,每一个叶子节点 就是l==r==k 是不是 val[k~i]
知道这个了,回到之前的问题,我们要求val[k+1~i]+dp[k][j-1]的最大值
因为这个dp[k][j-1]上一次已经求出来了,对这一次不产生任何影响了,是一个定值。
我们就只需要求val[k+1~j]
所以可以把这两个东西一起放到线段树里面,但是一个是l==r==k这个位置,一个是k+1这个位置,所以需要val往前面挪一下,或者dp[k]往后挪一下。
我选择第一种,那么就是每次更新,就更新 last[a[x+1]] 到 x 这个位置。
#include <cstdio> #include <cstring> #include <algorithm> #include <iostream> #include <algorithm> #include <cstdlib> #include <vector> #include <stack> #include <queue> #include <map> #include <string> #define inf 0x3f3f3f3f #define inf64 0x3f3f3f3f3f3f3f3f using namespace std; typedef long long ll; const int maxn = 4e4+ 10; int maxs[maxn * 4], lazy[maxn * 4]; int dp[maxn]; void push_up(int id) { maxs[id] = max(maxs[id << 1], maxs[id << 1 | 1]); } void build(int id,int l,int r) { lazy[id] = 0; maxs[id] = 0; if(l==r) { maxs[id] = dp[l]; return; } int mid = (l + r) >> 1; build(id << 1, l, mid); build(id << 1 | 1, mid + 1, r); push_up(id); } void push_down(int id) { if (lazy[id] == 0) return; maxs[id << 1] += lazy[id]; maxs[id << 1 | 1] += lazy[id]; lazy[id << 1] += lazy[id]; lazy[id << 1 | 1] += lazy[id]; lazy[id] = 0; } void update(int id,int l,int r,int x,int y,int val) { // printf("id=%d l=%d r=%d x=%d y=%d val=%d ", id, l, r, x, y, val); if(x<=l&&y>=r) { maxs[id] += val; lazy[id] += val; return; } push_down(id); int mid = (l + r) >> 1; if (x <= mid) update(id << 1, l, mid, x, y, val); if (y > mid) update(id << 1 | 1, mid + 1, r, x, y, val); push_up(id); } int query(int id,int l,int r,int x,int y) { if (x <= l && y >= r) return maxs[id]; push_down(id); int ans = 0, mid = (l + r) >> 1; if (x <= mid) ans = max(ans, query(id << 1, l, mid, x, y)); if (y > mid) ans = max(ans, query(id << 1 | 1, mid + 1, r, x, y)); return ans; } int last[maxn]; int a[maxn]; int main() { int n, k; scanf("%d%d", &n, &k); for (int i = 1; i <= n; i++) { scanf("%d", &a[i]); } for(int j=1;j<=k;j++) { memset(last, 0, sizeof(last)); build(1, 0, n); for(int i=1;i<=n;i++) { update(1, 0, n, last[a[i]], i - 1, 1); last[a[i]] = i; dp[i] = query(1, 0, n, 0, i - 1); } } printf("%d ", dp[n]); return 0; }