题目链接
题意分析
其实关于这道题的话 最下面的题意概括已经说的很明了
关于这道题 我们的第一想法是DP
\(f_i\)表示已经分好了前\(i\)个数字的最小代价 我们枚举k作为一段\([k,j]\)的开头进行转移
\[f_i=min\{f_{k-1}+\max_{j=k}^ih_j\}(\sum_{j=k}^ih_j≤m)
\]
暴力的话 \(O(n^2)\) 而且\(n≤10^5\) 所以我们需要考虑优化
我们要提一个十分重要的结论
当前枚举到了\(i\)的时候 如果把\(i\)作为定点 同时定义
\[hm_k=\max_{j=k}^ih_j\ \ \ \ z_k=f_{k-1}+\max_{j=k}^ih_j
\]
为了防止混乱 一定要时时刻刻明白一点 \(f_i\)是以\(i\)为右端点 而\(hm_k\)以及\(z_k\)都是以\(k\)为左端点
首先我们可以明白 从\(1\)到\(i\) \(hm_k\)是单减不增的
所以 我们加入一个\(h_i\)之后的话 存在一个临界点x 使得\(hm_x≥h_i\)
这个临界点\(x\)之后的\(hm_k\)都会被更新
由于存在单调性 所以我们可以二分出这个位置 然后使用区间赋值操作进行修改
同时对于当前的\(i\) 其答案就是
\[\max_{j=1}^iz_j
\]
等等! 二分 区间赋值 区间查询最值 这不就是线段树的骚操作?
我们使用线段树维护三样东西 \(f_i\)的最小值 \(hm_i\)的最大值 \(f_i+hm_{i+1}\)的最小值
首先 我们使用尺取法枚举当前的合法区间\([l,r]\) 保证
\[\sum_{j=l}^rh_j≤m
\]
同时按照上面所说 加入\(h_r\) 对于之前维护的答案产生影响并加以修改
然后将\(h_{r+1}\)插入线段树中 用以作为\(f_r+hm_{r+1}\)维护答案
最后到\(n\)的时候 就是查询
\[\max_{i=1}^n{f_i+hm_i}
\]
CODE:
#include<bits/stdc++.h>
#define N 800010
#define INF 2147483600
using namespace std;
int n,m;
struct Node
{
int m1,m2,m3;
int tag;
}tre[N];
int num[N];
void pushup(int now)
{
tre[now].m1=min(tre[now<<1].m1,tre[now<<1|1].m1);
tre[now].m2=max(tre[now<<1].m2,tre[now<<1|1].m2);
tre[now].m3=min(tre[now<<1].m3,tre[now<<1|1].m3);
}
void down(int now)
{
if(tre[now].tag)
{
tre[now<<1].m2=tre[now].tag;
tre[now<<1].m3=tre[now<<1].m1+tre[now].tag;
tre[now<<1].tag=tre[now].tag;
tre[now<<1|1].m2=tre[now].tag;
tre[now<<1|1].m3=tre[now<<1|1].m1+tre[now].tag;
tre[now<<1|1].tag=tre[now].tag;
tre[now].tag=0;
}
}
void insert(int now,int le,int ri,int pos,int fi,int hi)
{
if(le==ri)
{
tre[now].m1=fi;
tre[now].m2=hi;
tre[now].m3=fi+hi;
return;
}
int mid=(le+ri)>>1;down(now);
if(pos<=mid) insert(now<<1,le,mid,pos,fi,hi);
else insert(now<<1|1,mid+1,ri,pos,fi,hi);
pushup(now);
}
void update(int now,int lenow,int rinow,int le,int ri,int d)
{
if(le<=lenow&&rinow<=ri)
{
tre[now].m2=d;
tre[now].m3=tre[now].m1+d;
tre[now].tag=d;
return;
}
int mid=(lenow+rinow)>>1;down(now);
if(le<=mid) update(now<<1,lenow,mid,le,ri,d);
if(mid<ri) update(now<<1|1,mid+1,rinow,le,ri,d);
pushup(now);
}
int getdp(int now,int lenow,int rinow,int le,int ri)
{
if(le<=lenow&&rinow<=ri) return tre[now].m3;
int mid=(lenow+rinow)>>1,tmp=INF;down(now);
if(le<=mid) tmp=min(tmp,getdp(now<<1,lenow,mid,le,ri));
if(mid<ri) tmp=min(tmp,getdp(now<<1|1,mid+1,rinow,le,ri));
return tmp;
}
int geth(int now,int lenow,int rinow,int le,int ri)
{
if(le<=lenow&&rinow<=ri) return tre[now].m2;
int mid=(lenow+rinow)>>1,tmp=0;down(now);
if(le<=mid) tmp=max(tmp,geth(now<<1,lenow,mid,le,ri));
if(mid<ri) tmp=max(tmp,geth(now<<1|1,mid+1,rinow,le,ri));
return tmp;
}
int getat(int lenow,int rinow,int k)
{
int le=lenow,ri=rinow+1,ans=rinow+1;
while(le<ri)
{
// printf("now is %d %d\n",le,ri);
int mid=(le+ri)>>1;
if(geth(1,1,n,mid,rinow)<k) ans=ri=mid;
else le=mid+1;
}
return ans;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;++i) scanf("%d",&num[i]);
int nowsum=0;insert(1,1,n,1,0,num[1]);
for(int l=1,r=1;r<=n;++r)
{
nowsum+=num[r];
// printf("now is %d\n",r);
while(nowsum>m) nowsum-=num[l++];
int tmp=getat(l,r-1,num[r]);
if(tmp<r) update(1,1,n,tmp,r-1,num[r]);
int nowtmp=getdp(1,1,n,l,r);
if(r==n) printf("%d\n",nowtmp);
else insert(1,1,n,r+1,nowtmp,num[r+1]);
}
return 0;
}