https://acm.hdu.edu.cn/showproblem.php?pid=6981
题意:
给出2个n*n的矩阵A和B
起点在(1,1),终点在(n,n),每步只能往右或者往下走
得分为路径上的A的和与B的和的乘积
问最大得分
数据随机
解法一:搜索剪枝
因为是随机数据,估价函数优秀一些大概率还是可以过的
参考的这位大佬的做法
https://blog.csdn.net/George_Plover/article/details/119151411
设现在走过的A矩阵的和为sa,将来要走的和为a,现在走过的B矩阵的和为sb,将来要走的和为b
那么总乘积和可以表示为(sa+a)*(sb+b)=sa*sb+sa*b+sb*a+a*b
令C=A+B
则乘积和=sa*sb+sa*(c-a)+sb*a+a*(c-a)
假设当前位置右下角的数我们可以自由调整
c越大,乘积和越大
令Cmax表示当前位置右下角还能获取的最大的A+B之和
式子sa*sb+sa*(Cmax-a)+sb*a+a*(Cmax-a)是一个二元一次方程
当a=(sb-sa+Cmax)/2 时,方程的值最大,Cmax-a可求出对应的b
就可以用这个a b Cmax做最优化剪枝
#include<bits/stdc++.h> using namespace std; #define N 101 int n; int a[N][N],b[N][N],c[N][N]; long long ans; void dfs(int x,int y,int sa,int sb) { if(x==n && y==n) { ans=max(ans,1ll*sa*sb); return; } int aa,bb; if(x<n) { aa=sb-sa+c[x+1][y]>>1; bb=sa+c[x+1][y]-sb>>1; if(1ll*(sa+aa)*(sb+bb)>ans) dfs(x+1,y,sa+a[x+1][y],sb+b[x+1][y]); } if(y<n) { aa=sb-sa+c[x][y+1]>>1; bb=sa+c[x][y+1]-sb>>1; if(1ll*(sa+aa)*(sb+bb)>ans) dfs(x,y+1,sa+a[x][y+1],sb+b[x][y+1]); } } int main() { int T; scanf("%d",&T); while(T--) { scanf("%d",&n); for(int i=1;i<=n;++i) for(int j=1;j<=n;++j) scanf("%d",&a[i][j]); for(int i=1;i<=n;++i) for(int j=1;j<=n;++j) scanf("%d",&b[i][j]); for(int i=n;i;--i) for(int j=n;j;--j) { c[i][j]=a[i][j]+b[i][j]; if(i<n) c[i][j]=max(c[i][j],a[i][j]+b[i][j]+c[i+1][j]); if(j<n) c[i][j]=max(c[i][j],a[i][j]+b[i][j]+c[i][j+1]); } ans=0; dfs(1,1,a[1][1],b[1][1]); printf("%lld ",ans); } }
解法二:
std做法
令f[i][j][k]表示走到(i,j)A矩阵之和为k时,最大的B矩阵之和
当k1<=k2时,必须f[i][j][k2]>f[i][j][k1]
出题人说剔除掉无用状态后,k就只有几千个
然而不知道为啥
dp过程中剔除状态的方式值得学习
#include<bits/stdc++.h> using namespace std; #define N 101 int n; int a[N][N],b[N][N]; #define pr pair<int,int> #define mp make_pair vector<pr>f[N][N]; int cnt; pr tmp[1000003]; void add(pr d) { while(cnt &&d.second>=tmp[cnt].second) --cnt; if(!cnt || d.first>tmp[cnt].first) tmp[++cnt]=d; } void merge(vector<pr>&to,vector<pr>x,vector<pr>y) { int s1=x.size(),s2=y.size(),m=0,i=0,j=0; cnt=0; while(i<s1 && j<s2) add(x[i].first<y[j].first ? x[i++]:y[j++]); while(i<s1) add(x[i++]); while(j<s2) add(y[j++]); to.resize(cnt); for(int i=1;i<=cnt;++i) to[i-1]=tmp[i]; } int main() { int T,s; long long ans; scanf("%d",&T); while(T--) { scanf("%d",&n); for(int i=1;i<=n;++i) for(int j=1;j<=n;++j) scanf("%d",&a[i][j]); for(int i=1;i<=n;++i) for(int j=1;j<=n;++j) scanf("%d",&b[i][j]); f[1][1].clear(); f[1][1].push_back(mp(a[1][1],b[1][1])); for(int i=1;i<=n;++i) for(int j=1;j<=n;++j) { if(i==1 && j==1) continue; else if(i==1) f[i][j]=f[i][j-1]; else if(j==1) f[i][j]=f[i-1][j]; else merge(f[i][j],f[i-1][j],f[i][j-1]); s=f[i][j].size(); for(int k=0;k<s;++k) { f[i][j][k].first+=a[i][j]; f[i][j][k].second+=b[i][j]; } } s=f[n][n].size(); ans=0; for(int i=0;i<s;++i) ans=max(ans,1ll*f[n][n][i].first*f[n][n][i].second); printf("%lld ",ans); } }