4565: [Haoi2016]字符合并
Time Limit: 20 Sec Memory Limit: 256 MBSubmit: 690 Solved: 316
[Submit][Status][Discuss]
Description
有一个长度为 n 的 01 串,你可以每次将相邻的 k 个字符合并,得到一个新的字符并获得一定分数。得到的新字
符和分数由这 k 个字符确定。你需要求出你能获得的最大分数。
Input
第一行两个整数n,k。接下来一行长度为n的01串,表示初始串。接下来2k行,每行一个字符ci和一个整数wi,ci
表示长度为k的01串连成二进制后按从小到大顺序得到的第i种合并方案得到的新字符,wi表示对应的第i种方案对应
获得的分数。1<=n<=300,0<=ci<=1,wi>=1,k<=8
Output
输出一个整数表示答案
Sample Input
3 2
101
1 10
1 10
0 20
1 30
101
1 10
1 10
0 20
1 30
Sample Output
40
//第3行到第6行表示长度为2的4种01串合并方案。00->1,得10分,01->1得10分,10->0得20分,11->1得30分。
//第3行到第6行表示长度为2的4种01串合并方案。00->1,得10分,01->1得10分,10->0得20分,11->1得30分。
Solution
区间DP+状压DP
看到$k$的范围容易想到把$k$压起来。看到是算区间贡献容易想到区间DP。
区间DP需要分段,再合并。我们可以想到把一段区间强制合并成$0/1$,与另一段合并。也就是$dp[i][mid-1][s]+dp[mid][j][0/1]$,将$mid-j$强制合并成$0/1$,因此枚举$mid$分段。
首先了解一个东西,长度为$len$的一段区间,每次相邻$k$个合并,最后剩下的区间长度是$len%(k-1)?len%(k-1):k-1$
知道了这一点我们就可以确定$i-(mid-1)$这一段的状态了,然后每次判断转移的两个状态是否合法,转移即可。
所以我们枚举区间$[i,j]$,如果长度为$k$,就直接合并获得分数,但是注意在转移的时候不要在$f$数组中直接更新。举例说明$a[1]=0 a[2]=1$,所以$f[1][2][1]$是合法的状态,假设$01$可以合并成$0$,那么就可以更新$f[1][2][0]$,但是$f[1][2][0]$一旦成了合法状态,那么在之后枚举到$t=0$的时候,$dp[1][2][0]$又会去更新别的状态,但是这样是不合法的,所以我们要用临时数组来记录值,最后赋值给$f$数组。(by)
Code
#include<bits/stdc++.h> #define LL long long using namespace std; int n, k, a[305], c[305]; char s[305]; LL dp[305][305][305], w[305]; int main() { scanf("%d%d", &n, &k); scanf("%s", s + 1); for(int i = 1; i <= n; i ++) if(s[i] == '1') a[i] = 1; else a[i] = 0; memset(dp, 128, sizeof(dp)); LL oo = -dp[0][0][0]; for(int i = 0; i < (1 << k); i ++) scanf("%d %lld", &c[i], &w[i]); for(int i = 1; i <= n; i ++) dp[i][i][a[i]] = 0; for(int len = 2; len <= n; len ++) for(int i = 1; i + len - 1 <= n; i ++) { int j = i + len - 1; int tot = (j - i) % (k - 1) ? (j - i) % (k - 1) : k - 1; for(int mid = j; mid >= i; mid -= (k - 1)) for(int s = 0; s < (1 << tot); s ++) { if(dp[i][mid - 1][s] != -oo) { if(dp[mid][j][1] != -oo) dp[i][j][s << 1 | 1] = max(dp[i][j][s << 1 | 1], dp[i][mid - 1][s] + dp[mid][j][1]); if(dp[mid][j][0] != -oo) dp[i][j][s << 1] = max(dp[i][j][s << 1], dp[i][mid - 1][s] + dp[mid][j][0]); } } if(tot == k - 1) { LL g[2]; g[0] = g[1] = -oo; for(int s = 0; s < (1 << k); s ++) if(dp[i][j][s] != -oo) g[c[s]] = max(g[c[s]], dp[i][j][s] + w[s]); dp[i][j][1] = g[1], dp[i][j][0] = g[0]; } } LL ans = 0; for(int i = 0; i < (1 << k); i ++) ans = max(ans, dp[1][n][i]); printf("%lld", ans); return 0; }