欢迎访问~原文出处——博客园-zhouzhendong
去博客园看该题解
题目传送门 - BZOJ1084
题意概括
这里有一个n*m的矩阵,请你选出其中k个子矩阵,使得这个k个子矩阵分值之和最大。注意:选出的k个子矩阵不能相互重叠。
输入:第一行为n,m,k(1≤n≤100,1≤m≤2,1≤k≤10),接下来n行描述矩阵每行中的每个元素的分值(每个元素的分值的绝对值不超过32767)。
题解
注意到1<=m<=2!
如果m = 1 ,那么就是一个简单的线性dp。
我们设dp[i][j]表示在前i个里面选出k个子矩阵的最大分值。
那么分两种情况讨论:
1. 什么都不干: dp[i][j] = max(dp[i][j], dp[i-1][j])
2. 弄一个新的子矩阵: dp[i][j] = max(dp[i][j], dp[x][j - 1] + presum[i] - presum[x]) 0<=x<i
时间复杂度O(kn2)
如果 m = 2 ,那么是一个稍微复杂一点的线性dp。
我们设dp[i][j][x]表示在第一列的前i个和第二列的前j个里面选出x个子矩阵的最大分值。
那么分几种情况进行讨论:
1. 什么都不干: dp[i][j][x] = max(dp[i][j][x], dp[i - 1][j][x], dp[i][j - 1][x])
2. 在第一列弄一个新的子矩阵: dp[i][j][x] = max(dp[i][j][x], dp[y][j][x - 1] + presum[i][1] - presum[y][1]) 0<=y<i
3. 在第二列弄一个新的子矩阵: dp[i][j][x] = max(dp[i][j][x], dp[i][y][x - 1] + presum[j][2] - presum[y][2]) 0<=y<j
4. 在第一、二列弄一个宽度为2的子矩阵: dp[i][j][x] = max(dp[i][j][x], dp[y][y][x - 1] + presum[i][1] - presum[y][1] + presum[j][2] - presum[y][2]) i = j 且 0<=y<i
时间复杂度O(kn3)
代码
#include <cstring> #include <algorithm> #include <cstdlib> #include <cstdio> #include <cmath> using namespace std; const int N=100+5,M=5,K=10+5; const int Inf=1<<25; int n,m,k,a[N][M]; void solve1(){ int dp[N][K],presum[N]; for (int i=0;i<N;i++) for (int j=0;j<K;j++) dp[i][j]=-Inf; presum[0]=0; for (int i=1;i<=n;i++) presum[i]=presum[i-1]+a[i][1]; dp[0][0]=0; int ans=-Inf; for (int i=0;i<=n;i++) for (int j=0;j<=k;j++){ if (!i&&!j) continue; if (i) dp[i][j]=dp[i-1][j]; if (!j) continue; for (int x=0;x<i;x++) dp[i][j]=max(dp[i][j],dp[x][j-1]+presum[i]-presum[x]); } printf("%d",dp[n][k]); } void solve2(){ int dp[N][N][K],presum[N][M]; presum[0][1]=presum[0][2]=0; for (int i=1;i<=n;i++){ presum[i][1]=presum[i-1][1]+a[i][1]; presum[i][2]=presum[i-1][2]+a[i][2]; } for (int i=0;i<N;i++) for (int j=0;j<N;j++) for (int x=0;x<K;x++) dp[i][j][x]=-Inf; dp[0][0][0]=0; for (int i=0;i<=n;i++) for (int j=0;j<=n;j++) for (int x=0;x<=k;x++){ if (!i&&!j&&!x) continue; if (i&&j) dp[i][j][x]=max(dp[i-1][j][x],dp[i][j-1][x]); else if (i) dp[i][j][x]=dp[i-1][j][x]; else if (j) dp[i][j][x]=dp[i][j-1][x]; if (!x) continue; for (int y=0;y<i;y++) dp[i][j][x]=max(dp[i][j][x],dp[y][j][x-1]+presum[i][1]-presum[y][1]); for (int y=0;y<j;y++) dp[i][j][x]=max(dp[i][j][x],dp[i][y][x-1]+presum[j][2]-presum[y][2]); if (i==j) for (int y=0;y<i;y++) dp[i][j][x]=max(dp[i][j][x],dp[y][y][x-1]+presum[i][1]-presum[y][1]+presum[j][2]-presum[y][2]); } printf("%d",dp[n][n][k]); } int main(){ scanf("%d%d%d",&n,&m,&k); for (int i=1;i<=n;i++) for (int j=1;j<=m;j++) scanf("%d",&a[i][j]); if (m==1) solve1(); else solve2(); return 0; }