(n imes m) 的网格中,在第 (i) 行 (j) 列有 (a[i][j]) 个泡泡,每次可以收割一行或一列的泡泡,最多收割 (4) 次,问最多可以收割到多少泡泡。(nm leq 10^5)
Solution
讨论答案的各种情况
- 四行,这种情况下直接求和取前 (4) 个最大值即可
- 三行一列,枚举取哪一列,然后每次暴力提取前 (3) 个行最大值
- 两行两列,显然 (n,m) 中必有一个 (leq sqrt{10^5}),设它是行,则暴力枚举选哪两行,然后仍然按照前述方法计算答案即可
其余情况可以由上面三种基本情况旋转得到
复杂度 (mathcal{O} (nm min(n,m)))
#include <bits/stdc++.h>
using namespace std;
#define int long long
int n,m,**a,**b;
int solve1a() {
int ans=0;
vector <int> v;
for(int i=1;i<=n;i++) {
int sum=0;
for(int j=1;j<=m;j++) sum+=a[i][j];
v.push_back(sum);
}
sort(v.begin(),v.end());
for(int i=0;i<4;i++) {
if(v.size()) ans+=v.back(), v.pop_back();
}
return ans;
}
int solve1b() {
int ans=0;
vector <int> v;
for(int i=1;i<=m;i++) {
int sum=0;
for(int j=1;j<=n;j++) sum+=a[j][i];
v.push_back(sum);
}
sort(v.begin(),v.end());
for(int i=0;i<4;i++) {
if(v.size()) ans+=v.back(), v.pop_back();
}
return ans;
}
int solve2a() {
int ans=0;
int *sum,*tmp;
sum=new int[n+1];
tmp=new int[n+1];
for(int i=1;i<=n;i++) sum[i]=0;
for(int i=1;i<=n;i++) {
for(int j=1;j<=m;j++) {
sum[i]+=a[i][j];
}
}
for(int k=1;k<=m;k++) {
int tot=0;
for(int i=1;i<=n;i++) tot+=a[i][k];
for(int i=1;i<=n;i++) {
tmp[i]=sum[i]-a[i][k];
}
for(int i=0;i<3;i++) {
tot+=*max_element(tmp+1,tmp+n+1);
*max_element(tmp+1,tmp+n+1)=0;
}
ans=max(ans,tot);
}
return ans;
}
int solve2b() {
int ans=0;
int *sum,*tmp;
sum=new int[m+1];
tmp=new int[m+1];
for(int i=1;i<=m;i++) sum[i]=0;
for(int i=1;i<=m;i++) {
for(int j=1;j<=n;j++) {
sum[i]+=a[j][i];
}
}
for(int k=1;k<=n;k++) {
int tot=0;
for(int i=1;i<=m;i++) tot+=a[k][i];
for(int i=1;i<=m;i++) tmp[i]=sum[i]-a[k][i];
for(int i=0;i<3;i++) {
tot+=*max_element(tmp+1,tmp+m+1);
*max_element(tmp+1,tmp+m+1)=0;
}
ans=max(ans,tot);
}
return ans;
}
int solve3a() {
int ans=0;
int *sum,*tmp;
sum=new int[m+1];
tmp=new int[m+1];
for(int i=1;i<=m;i++) sum[i]=0;
for(int i=1;i<=m;i++) {
for(int j=1;j<=n;j++) {
sum[i]+=a[j][i];
}
}
for(int k=1;k<=n;k++) {
for(int l=1;l<=n;l++) if(k!=l) {
int tot=0;
for(int i=1;i<=m;i++) tot+=a[k][i]+a[l][i];
for(int i=1;i<=m;i++) tmp[i]=sum[i]-a[k][i]-a[l][i];
for(int i=0;i<2;i++) {
tot+=*max_element(tmp+1,tmp+m+1);
*max_element(tmp+1,tmp+m+1)=0;
}
ans=max(ans,tot);
}
}
return ans;
}
int solve3b() {
int ans=0;
int *sum,*tmp;
sum=new int[n+1];
tmp=new int[n+1];
for(int i=1;i<=n;i++) sum[i]=0;
for(int i=1;i<=n;i++) {
for(int j=1;j<=m;j++) {
sum[i]+=a[i][j];
}
}
for(int k=1;k<=m;k++) {
for(int l=1;l<=m;l++) if(k!=l) {
int tot=0;
for(int i=1;i<=n;i++) tot+=a[i][k]+a[i][l];
for(int i=1;i<=n;i++) tmp[i]=sum[i]-a[i][k]-a[i][l];
for(int i=0;i<2;i++) {
tot+=*max_element(tmp+1,tmp+n+1);
*max_element(tmp+1,tmp+n+1)=0;
}
ans=max(ans,tot);
}
}
return ans;
}
signed main() {
ios::sync_with_stdio(false);
cin>>n>>m;
a=new int*[n+1];
b=new int*[n+1];
for(int i=0;i<=n;i++) {
a[i]=new int[m+1];
b[i]=new int[m+1];
}
for(int i=1;i<=n;i++) {
for(int j=1;j<=m;j++) {
cin>>a[i][j];
b[i][j]=a[i][j];
}
}
int ans=0;
ans=max(ans,solve1a());
ans=max(ans,solve1b());
ans=max(ans,solve2a());
ans=max(ans,solve2b());
if(n<m) ans=max(ans,solve3a());
else ans=max(ans,solve3b());
cout<<ans<<endl;
}