Description
Input
第1行,一个整数N;
第2~n+1行,每行一个整数表示序列a。
Output
输出答案对10^9取模后的结果。
Sample Input
4
2
4
1
4
2
4
1
4
Sample Output
109
【数据范围】
N <= 500000
1 <= a_i <= 10^8
【数据范围】
N <= 500000
1 <= a_i <= 10^8
HINT
Source
真是********
考虑分治法,问题转化为如何解决跨过中点的答案。
对于一个序列[i,j],倒序枚举左端点i,则整个的maxv、minv的值有四种可能:
1.均在左一半取。
计算满足条件的右端最多到哪,记作cur,因为单调性,可以O(1)求。
ans+=maxvl*minvl*((mid+1-i+1)+(mid+2-i+1)+···+(cur-i+1))
变形,用等差数列求和就可以O(1)算了。
2.maxv在左取,minv在右取。
计算满足条件的右端最多到哪,记作curmn,因为单调性,可以O(1)求。
ans+=maxvl*[minvj*(j-i+1)+minvj+1*(j+1-i+1)+···+minvcurmn*(curmn-i+1)]
变形,维护一下minvj*j和minvj的前缀和就可以O(1)算了。
3.minv在右取,maxv在左取。
计算满足条件的右端最多到哪,记作curmx,因为单调性,可以O(1)求。
ans+=minvl*[maxvj*(j-i+1)+maxvj+1*(j+1-i+1)+···+maxvcurmx*(curmx-i+1)]
变形,维护一下maxvj*j和maxvj的前缀和就可以O(1)算了。
4.均在右一半取。
ans+=maxvj*minvj*(j-i+1)+···+maxvr*minvr*(r-i+1)。
变形,维护一下maxvj*minvj*j和maxvj*minvj的前缀和就可以O(1)算了。
说起来还真是简单呢!
#include<cstdio> #include<cctype> #include<queue> #include<cmath> #include<cstring> #include<algorithm> #define rep(i,s,t) for(int i=s;i<=t;i++) #define dwn(i,s,t) for(int i=s;i>=t;i--) #define ren for(int i=first[x];i!=-1;i=next[i]) using namespace std; const int BufferSize=1<<16; char buffer[BufferSize],*head,*tail; inline char Getchar() { if(head==tail) { int l=fread(buffer,1,BufferSize,stdin); tail=(head=buffer)+l; } return *head++; } inline int read() { int x=0,f=1;char c=Getchar(); for(;!isdigit(c);c=Getchar()) if(c=='-') f=-1; for(;isdigit(c);c=Getchar()) x=x*10+c-'0'; return x*f; } const int maxn=500010; const int mod=1000000000; typedef long long ll; struct Pair { ll mx,mn; }q1[maxn],q2[maxn]; Pair update(Pair a,ll v) { a.mx=max(a.mx,v);a.mn=min(a.mn,v); return a; } int n; ll A[maxn],ans,S1[maxn],S2[maxn],S3[maxn],SS1[maxn],SS2[maxn],SS3[maxn]; ll sumlen(int b,int s,int t) { if(s>t) return 0; return (ll)(s-b+1+t-b+1)*(t-s+1)/2%mod; } void solve(int l,int r) { if(r-l+1<=10) { rep(i,l,r) { ll mx=A[i],mn=A[i]; rep(j,i,r) { mx=max(mx,A[j]);mn=min(mn,A[j]); (ans+=(j-i+1)*mx*mn)%=mod; } } return; } int mid=l+r>>1;solve(l,mid);solve(mid+1,r); q1[mid]=(Pair){A[mid],A[mid]};dwn(i,mid-1,l) q1[i]=update(q1[i+1],A[i]); q2[mid+1]=(Pair){A[mid+1],A[mid+1]};rep(i,mid+2,r) q2[i]=update(q2[i-1],A[i]); q1[l-1]=q2[r+1]=(Pair){1e9,-1e9}; S1[mid]=S2[mid]=S3[mid]=SS1[mid]=SS2[mid]=SS3[mid]=0; rep(i,mid+1,r) { S1[i]=(S1[i-1]+q2[i].mx*(i-mid))%mod; SS1[i]=(SS1[i-1]+q2[i].mx)%mod; S2[i]=(S2[i-1]+q2[i].mn*(i-mid))%mod; SS2[i]=(SS2[i-1]+q2[i].mn)%mod; S3[i]=(S3[i-1]+q2[i].mx*q2[i].mn%mod*(i-mid))%mod; SS3[i]=(SS3[i-1]+q2[i].mn*q2[i].mx)%mod; } int cur=mid+1,curmx=mid+1,curmn=mid+1; dwn(i,mid,l) { while(q2[cur].mx<=q1[i].mx&&q2[cur].mn>=q1[i].mn) cur++; while(q2[curmx].mx<=q1[i].mx) curmx++; while(q2[curmn].mn>=q1[i].mn) curmn++; ll t=ans; (ans+=sumlen(i,mid+1,cur-1)*q1[i].mx%mod*q1[i].mn%mod)%=mod; if(curmn-1>=curmx) (ans+=q1[i].mn*((S1[curmn-1]-S1[curmx-1]+mod+(SS1[curmn-1]-SS1[curmx-1]+mod)*(mid-i+1))%mod))%=mod; if(curmx-1>=curmn) (ans+=q1[i].mx*((S2[curmx-1]-S2[curmn-1]+mod+(SS2[curmx-1]-SS2[curmn-1]+mod)*(mid-i+1))%mod))%=mod; int p=max(curmx,curmn); if(p<=r) (ans+=S3[r]-S3[p-1]+mod+(SS3[r]-SS3[p-1]+mod)*(mid-i+1))%=mod; } } int main() { n=read();rep(i,1,n) A[i]=read(); solve(1,n);printf("%lld ",ans); return 0; }