题目大意:
求一个序列的第k大的子串和。
题解:
对于一个右端点找最优的左端点,扔进堆里。
每次取堆顶,将这个右端点可以选择的左端点的区间分成两段,扔进堆里,重复k次。
现在需要对于一个固定的右端点,左端点在一个区间里,求最大值。
可持久化线段树上区间修改,不用标记永久化也可以过。
代码:
#include<cstdio> #include<algorithm> #include<map> #include<queue> #define mp make_pair #define pr pair<long long,int> #define prr pair<pr,pr> #define fr first #define sc second using namespace std; int n,k,cnt,ls[10000005],rs[10000005],root[200005]; long long tag[10000005]; priority_queue<prr> q; map<int,int> pre; struct node{ long long val; int id; }tree[10000005]; void build(int &x,int l,int r){ x=++cnt; tree[x]=(node){0,l}; if (l==r) return; int mid=(l+r)>>1; build(ls[x],l,mid); build(rs[x],mid+1,r); } void add(int &now,int pre,long long key){ now=++cnt; ls[now]=ls[pre],rs[now]=rs[pre],tree[now]=tree[pre],tag[now]=tag[pre]+key; tree[now].val+=key; } void push_down(int x){ if (!tag[x]) return; add(ls[x],ls[x],tag[x]); add(rs[x],rs[x],tag[x]); tag[x]=0; } void insert(int &now,int pre,int l,int r,int x,int y,int key){ if (l>y || r<x) return; if (l>=x && r<=y){ add(now,pre,key); return; } push_down(pre); now=++cnt; ls[now]=ls[pre],rs[now]=rs[pre],tree[now]=tree[pre]; int mid=(l+r)>>1; insert(ls[now],ls[pre],l,mid,x,y,key); insert(rs[now],rs[pre],mid+1,r,x,y,key); if (tree[rs[now]].val>tree[ls[now]].val) tree[now]=tree[rs[now]]; else tree[now]=tree[ls[now]]; } node query(int now,int l,int r,int x,int y){ if (!now) return (node){-1ll<<60,0}; if (l>y || r<x) return (node){-1ll<<60,0}; if (l>=x && r<=y) return tree[now]; push_down(now); int mid=(l+r)>>1; node max1=query(ls[now],l,mid,x,y); node max2=query(rs[now],mid+1,r,x,y); if (max1.val>max2.val) return max1; else return max2; } void insert(int x,int l,int r){ if (l>r) return; node sum=query(x,1,n,l,r); q.push(mp(mp(sum.val,x),mp(l,r))); } int main(){ scanf("%d%d",&n,&k); build(root[0],1,n); for (int i=1; i<=n; i++){ int x; scanf("%d",&x); insert(root[i],root[i-1],1,n,pre[x]+1,i,x); pre[x]=i; insert(root[i],1,i); } long long sum; while (k--){ sum=q.top().fr.fr; int id=q.top().fr.sc,l=q.top().sc.fr,r=q.top().sc.sc; q.pop(); int mid=query(id,1,n,l,r).id; insert(id,l,mid-1); insert(id,mid+1,r); } printf("%lld ",sum); return 0; }