我们先将所有连续正段和连续负段合并,那么选取所有正段一定是最优的,但是选取的段数有可能超过$m$段,这时我们就需要合并。
1.选取不在两端的一个负段与它两边正段合并,块数减少$1$,子段和减少$|v[i]|$。
2.选取一个正段,将其删除,并与其左右负段合并,块数减少$1$,子段和减少$|v[i]|$。
所以我们将每段按照$|v[i]|$大小排序,每次贪心合并至段数小于等于$m$。
1 #include <iostream> 2 #include <stdio.h> 3 #include <queue> 4 #include <algorithm> 5 using namespace std; 6 #define pa pair<int,int> 7 #define maxn 100001 8 #define INF 0x3f3f3f3f 9 inline int read() 10 { 11 int s=0,f=1; 12 char ch=getchar(); 13 while(ch<'0'||ch>'9') 14 { 15 if(ch=='-') 16 f=-1; 17 ch=getchar(); 18 } 19 while(ch>='0'&&ch<='9') 20 s=(s<<1)+(s<<3)+ch-'0',ch=getchar(); 21 return s*f; 22 } 23 int pr[maxn],nx[maxn],s[maxn]; 24 int n,k,ans; 25 priority_queue<pa,vector<pa>,greater<pa> >q; 26 int main() 27 { 28 n=read(); 29 k=read(); 30 int x,y; 31 for(int i=1;i<=n;i++) 32 { 33 x=read(); 34 s[i]=x-y; 35 y=x; 36 if(i!=1) 37 q.push(pa(s[i],i)); 38 pr[i]=i-1; 39 nx[i]=i+1; 40 } 41 pr[2]=0; 42 nx[n]=0; 43 while(k--) 44 { 45 while(q.top().first!=s[q.top().second]) 46 q.pop(); 47 int u=q.top().second; 48 int l=pr[u],r=nx[u]; 49 q.pop(); 50 ans+=s[u]; 51 s[u]=l&&r?s[l]+s[r]-s[u]:INF; 52 pr[nx[u]=nx[r]]=u; 53 nx[pr[u]=pr[l]]=u; 54 s[l]=s[r]=INF; 55 q.push(pa(s[u],u)); 56 } 57 printf("%d",ans); 58 }