CF573E
题意概要
给出一个长度为(n)的数列,从中选出一个子序列(b[1...m])(可以为空)
使得$$ sum_{i=1}^m{b_i*i}$$最大,输出这个最大值。
其中(nle10^5)
题解
设(dp_{i,j})表示前(i)个数选择(j)个数的最大值
那么,转移方程则为:
[dp_{i,j}=max(dp_{i-1,j},dp_{i-1,j-1}+j*a_i)
]
于是我们就得到了一个(n^2)的做法
我们考虑优化这个式子。
经(dalao)证明,发现总有存在一个分界线,这之前的取前者,这之后的取后者
大佬的证明在这里
我们二分分界线,然后用平衡树维护就好了
顺带一提,我今日方知(splay)没事多$ splay $几下还会变快
代码
#include<bits/stdc++.h>
#include<windows.h>
#define lch c[x][0]
#define rch c[x][1]
using namespace std;
typedef long long ll;
const int sz=1e5+7;
int n;
ll v,ans;
int rt,cnt;
int f[sz];
int c[sz][2];
int siz[sz];
ll val[sz],tag1[sz],tag2[sz];
inline int newnode(ll v){
int x=++cnt;
val[x]=v;
siz[x]=1;
return x;
}
inline void pushup(int x){
siz[x]=siz[lch]+siz[rch]+1;
}
inline void add(int x,ll tg1,ll tg2){
val[x]+=siz[lch]*tg1+tg2;
tag1[x]+=tg1;
tag2[x]+=tg2;
}
inline void pd(int x){
if(tag1[x]==0&&tag2[x]==0) return;
if(lch) add(lch,tag1[x],tag2[x]);
if(rch) add(rch,tag1[x],tag2[x]+(siz[lch]+1)*tag1[x]);
tag1[x]=0;
tag2[x]=0;
}
inline void pushdn(int x){
if(f[x]) pushdn(f[x]);
pd(x);
}
inline int get(int x){
return c[f[x]][1]==x;
}
inline void dfs(int x){
ans=max(ans,val[x]);
pd(x);
if(lch) dfs(lch);
if(rch) dfs(rch);
}
inline void rotate(int x){
int y=f[x],z=f[y],k=get(x),w=c[x][!k];
if(z) c[z][get(y)]=x;c[x][!k]=y;c[y][k]=w;
if(w) f[w]=y;if(y) f[y]=x;f[x]=z;
pushup(y);
}
inline void splay(int x,int t){
pushdn(x);
while(f[x]!=t){
int y=f[x];
if(f[y]!=t) rotate(get(x)^get(y)?x:y);
rotate(x);
}
pushup(x);
if(t==0) rt=x;
}
inline int find(int k){
int x=rt;
while(1){
pd(x);
if(siz[lch]+1==k) return x;
if(k<=siz[lch]) x=lch;
else k-=siz[lch]+1,x=rch;
}
}
inline void insert(int k,ll v){
int x=find(k-1);splay(x,0);
int y=find(k);splay(y,x);
c[y][0]=newnode(v);
f[c[y][0]]=y;
pushup(y);
pushup(x);
}
inline void modify(int l,int r,ll a,ll b){
int x=find(l-1);splay(x,0);
int y=find(r+1);splay(y,x);
add(c[y][0],a,b);
pushup(y);
pushup(x);
}
int main(){
scanf("%d",&n);
rt=newnode(0);
c[rt][0]=newnode(INT_MIN);
c[rt][1]=newnode(INT_MIN);
f[c[rt][0]]=f[c[rt][1]]=rt;
pushup(rt);
for(int i=1;i<=n;i++){
scanf("%lld",&v);
int l=2,r=i+1;
while(l<r){
int mid=(l+r)>>1;
if(val[find(mid)]+(mid-1)*v>=val[find(mid+1)]) r=mid;
else l=mid+1;
splay(find(mid),0);
}
int x=find(l);
ll y=val[x];
modify(l,i+1,v,v*(l-1));
insert(l,y);
}
dfs(rt);
printf("%lld
",ans);
}