题目链接
题目思路
首先要想到一个dp模型
\(dp[i][0]\)为右端点为i前一个子区间小于等于k的最大贡献
\(dp[i][1]\)为右端点为i前一个子区间大于k的最大贡献
那么这样利用前缀和可以写出一个\(O(n^2)\)的dp
暴力dp
#include<bits/stdc++.h>
#define fi first
#define se second
#define debug cout<<"I AM HERE"<<endl;
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
const int maxn=1e5+5,inf=0x3f3f3f3f,mod=1e9+7;
const double eps=1e-3;
const ll INF=0x3f3f3f3f3f3f3f3f;
int n,k;
int a[maxn];
ll dp[maxn][2],pre[maxn];
signed main(){
scanf("%d%d",&n,&k);
for(int i=0;i<=n;i++){
dp[i][0]=dp[i][1]=-INF;
}
for(int i=1;i<=n;i++){
scanf("%d",&a[i]);
pre[i]=pre[i-1]+a[i];
}
dp[0][0]=0;//
for(int i=1;i<=n;i++){
for(int j=0;j<=i-k-1;j++){ //[i-k,i]
dp[i][1]=max({dp[i][1],dp[j][0]+pre[i]-pre[j],dp[j][1]+pre[i]-pre[j]});
}
for(int j=max(i-k,0);j<=i-1;j++){ //[j+1,i] i-j
dp[i][0]=max({dp[i][0],dp[j][0]+pre[i]-pre[j],dp[j][1]+2*(pre[i]-pre[j])});
}
}
printf("%lld\n",max(dp[n][1],dp[n][0]));
return 0;
}
仔细观察这个dp
你会发现每次dp只需要求
区间\((0,i-k+1)\)中的\(\max(dp[j][0]-pre[j],dp[j][1]-pre[j])\)
以及
区间\((\max(i-k,0),i-1)\)中的\(\max(dp[j][0]-pre[j],dp[j][1]-2*pre[j])\)
这样只要维护三个线段树即可
代码
#include<bits/stdc++.h>
#define fi first
#define se second
#define debug cout<<"I AM HERE"<<endl;
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
const int maxn=1e5+5,inf=0x3f3f3f3f,mod=1e9+7;
const double eps=1e-3;
const ll INF=0x3f3f3f3f3f3f3f3f;
int n,k;
int a[maxn];
ll dp[maxn][2],pre[maxn];
// dp[i][0]右端点为i前一个子区间小于等于k
// dp[i][1]右端点为i前一个子区间大于k
ll tree[maxn<<2][4];
void build(int node,int l,int r){
if(l==r){
tree[node][1]=tree[node][2]=tree[node][3]=-INF;
if(l==0){
tree[node][1]=0;
}
return ;
}
int mid=(l+r)/2;
build(node<<1,l,mid);
build(node<<1|1,mid+1,r);
tree[node][1]=max(tree[node<<1][1],tree[node<<1|1][1]);
tree[node][2]=max(tree[node<<1][2],tree[node<<1|1][2]);
tree[node][3]=max(tree[node<<1][3],tree[node<<1|1][3]);
}
ll query(int node,int l,int r,int L,int R,int id){
if(L<=l&&r<=R){
return tree[node][id];
}
int mid=(l+r)/2;
ll ma=-INF;
if(mid>=L) ma=max(ma,query(node<<1,l,mid,L,R,id));
if(mid<R) ma=max(ma,query(node<<1|1,mid+1,r,L,R,id));
return ma;
}
void update(int node,int pos,int l,int r,ll val,int id){
if(l==r){
tree[node][id]=val;
return ;
}
int mid=(l+r)/2;
if(mid>=pos) update(node<<1,pos,l,mid,val,id);
else update(node<<1|1,pos,mid+1,r,val,id);
tree[node][id]=max(tree[node<<1][id],tree[node<<1|1][id]);
}
signed main(){
scanf("%d%d",&n,&k);
for(int i=0;i<=n;i++){
dp[i][0]=dp[i][1]=-INF;
}
for(int i=1;i<=n;i++){
scanf("%d",&a[i]);
pre[i]=pre[i-1]+a[i];
}
build(1,0,n);
dp[0][0]=0;
// 第一颗线段树维护 dp[j][0]-pre[j]
// 第二棵线段树维护 dp[j][1]-pre[j]
// 第二棵线段树维护 dp[j][1]-2*pre[j]
for(int i=1;i<=n;i++){
if(i-k-1>=0){
dp[i][1]=max(pre[i]+query(1,0,n,0,i-k-1,1), pre[i]+query(1,0,n,0,i-k-1,2));
}
dp[i][0]=max(pre[i]+query(1,0,n,max(i-k,0),i-1,1),2*pre[i]+query(1,0,n,max(i-k,0),i-1,3));
update(1,i,0,n,dp[i][0]-pre[i],1);
update(1,i,0,n,dp[i][1]-pre[i],2);
update(1,i,0,n,dp[i][1]-2*pre[i],3);
}
printf("%lld\n",max(dp[n][1],dp[n][0]));
return 0;
}