题意:求一个序列的最大的(区间最小值*区间和)
线段树做法:用单调栈求出每个数两边比它大的左右边界,然后用线段树求出每段区间的和sum、最小前缀lsum、最小后缀rsum,枚举每个数a[i],设以a[i]为最小值的区间为[l,r]
若a[i]>0,则最优解就是a[i]*([l,r]的区间和),因为[l,r]上的数都比a[i]大。
若a[i]<0,则最优解是a[i]*([l,i-1]上的最小后缀+a[i]+[i+1,r]上的最小前缀),在线段树上查询即可。
复杂度$O(nlogn)$
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 const int N=5e5+10,inf=0x3f3f3f3f; 5 int a[N],n,sta[N],L[N],R[N],tp; 6 #define ls (u<<1) 7 #define rs (u<<1|1) 8 #define mid ((l+r)>>1) 9 struct D {ll sum,lsum,rsum;} s[N<<2]; 10 D mg(D a,D b) { 11 D t= {0,0}; 12 t.sum=a.sum+b.sum; 13 t.lsum=min(a.lsum,a.sum+b.lsum); 14 t.rsum=min(b.rsum,b.sum+a.rsum); 15 return t; 16 } 17 void build(int u=1,int l=1,int r=n) { 18 if(l==r) {s[u].sum=a[l],s[u].lsum=min((ll)a[l],0ll),s[u].rsum=min((ll)a[l],0ll); return;} 19 build(ls,l,mid),build(rs,mid+1,r),s[u]=mg(s[ls],s[rs]); 20 } 21 void qry(int L,int R,D& x,int u=1,int l=1,int r=n) { 22 if(l>=L&&r<=R) {x=mg(x,s[u]); return;} 23 if(l>R||r<L)return; 24 qry(L,R,x,ls,l,mid),qry(L,R,x,rs,mid+1,r); 25 } 26 int main() { 27 scanf("%d",&n); 28 a[0]=a[n+1]=~inf; 29 for(int i=1; i<=n; ++i)scanf("%d",&a[i]); 30 sta[tp=0]=0; 31 for(int i=1; i<=n; ++i) { 32 for(; a[sta[tp]]>=a[i]; --tp); 33 L[i]=sta[tp]+1,sta[++tp]=i; 34 } 35 sta[tp=0]=n+1; 36 for(int i=n; i>=1; --i) { 37 for(; a[sta[tp]]>=a[i]; --tp); 38 R[i]=sta[tp]-1,sta[++tp]=i; 39 } 40 build(); 41 ll ans=0; 42 for(int i=1; i<=n; ++i) { 43 if(a[i]>0) { 44 D t= {0,0}; 45 qry(L[i],R[i],t); 46 ans=max(ans,a[i]*t.sum); 47 } else if(a[i]<0) { 48 ll x=0; 49 D t= {0,0}; 50 qry(L[i],i,t); 51 x+=t.rsum; 52 t= {0,0}; 53 qry(i,R[i],t); 54 x+=t.lsum; 55 x-=a[i]; 56 ans=max(ans,a[i]*x); 57 } 58 } 59 printf("%lld ",ans); 60 return 0; 61 }
笛卡尔树做法:对整个序列建立笛卡尔树,用和线段树相同的方法求出每个结点的子树所代表区间的sum,lsum,rsum,枚举每个结点,如果是正数则乘上该结点的sum,如果是负数则乘上该结点的(左儿子的rsum+右儿子的lsum+结点本身的值)即可。
复杂度$O(n)$
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 const int N=5e5+10,inf=0x3f3f3f3f; 5 int n,a[N],ls[N],rs[N],sta[N],tp; 6 ll sum[N],lsum[N],rsum[N]; 7 void build() { 8 a[n+1]=~inf,sta[tp=0]=n+1; 9 for(int i=1; i<=n; ++i) { 10 for(; a[i]<a[sta[tp]]; --tp); 11 ls[i]=rs[sta[tp]],rs[sta[tp]]=i,sta[++tp]=i; 12 } 13 } 14 void dfs(int u) { 15 if(!u)return; 16 dfs(ls[u]),dfs(rs[u]); 17 sum[u]=sum[ls[u]]+a[u]+sum[rs[u]]; 18 lsum[u]=min(lsum[ls[u]],sum[ls[u]]+a[u]+lsum[rs[u]]); 19 rsum[u]=min(rsum[rs[u]],sum[rs[u]]+a[u]+rsum[ls[u]]); 20 } 21 int main() { 22 scanf("%d",&n); 23 for(int i=1; i<=n; ++i)scanf("%d",&a[i]); 24 build(),dfs(rs[n+1]); 25 ll ans=0; 26 for(int i=1; i<=n; ++i) { 27 if(a[i]>0)ans=max(ans,a[i]*sum[i]); 28 else if(a[i]<0)ans=max(ans,a[i]*(rsum[ls[i]]+a[i]+lsum[rs[i]])); 29 } 30 printf("%lld ",ans); 31 return 0; 32 }