一道splay综合大板子题。
题面:https://www.lydsy.com/JudgeOnline/problem.php?id=1500
下面是题解:
首先对每个点维护这些量:
1.两个儿子(ch[2])
2.父节点(fa)
3.当前点权值(vl)和子树权值(sum)
4.修改标记(xg),反转标记(fz)。
5.mx,mxl,mxr(不清楚的建议做小白逛公园)。
题中还有一个条件,即:
100%的数据中,任何时刻数列中最多含有 500 000 个数。
这就解决了空间问题。我们可以用一个队列记录有多少点可以回收,这样就可以节省大量的空间。
最后一点:本题中涉及到max值,区间反转后需要update。
代码:
#include<queue> #include<cstdio> #include<algorithm> using namespace std; #define N 1000050 int n,m,a[N],rt,tot,pos,cnt,id[N]; queue<int>que; char ch[15]; struct Splay { int ch[2]; int xg,fz,fa; int vl,sum,siz,mx,mxl,mxr; }tr[N]; void update(int u) { int l= tr[u].ch[0] , r = tr[u].ch[1]; tr[u].sum = tr[l].sum+tr[r].sum+tr[u].vl; tr[u].siz = tr[l].siz+tr[r].siz+1; tr[u].mx = max(max(tr[l].mx,tr[r].mx),tr[l].mxr+tr[r].mxl+tr[u].vl); tr[u].mxl = max(tr[l].mxl,tr[l].sum+tr[u].vl+tr[r].mxl); tr[u].mxr = max(tr[r].mxr,tr[r].sum+tr[u].vl+tr[l].mxr); } void pushdown(int u) { int l = tr[u].ch[0],r = tr[u].ch[1]; if(tr[u].xg) { tr[u].xg = tr[u].fz = 0; if(l)tr[l].xg = 1,tr[l].vl = tr[u].vl,tr[l].sum = tr[l].siz*tr[l].vl; if(r)tr[r].xg = 1,tr[r].vl = tr[u].vl,tr[r].sum = tr[r].siz*tr[r].vl; if(tr[u].vl>0) { if(l)tr[l].mx=tr[l].mxl=tr[l].mxr=tr[l].sum; if(r)tr[r].mx=tr[r].mxl=tr[r].mxr=tr[r].sum; }else { if(l)tr[l].mx=tr[l].vl,tr[l].mxl=tr[l].mxr=0; if(r)tr[r].mx=tr[r].vl,tr[r].mxl=tr[r].mxr=0; } } if(tr[u].fz) { tr[u].fz = 0; tr[l].fz^=1,tr[r].fz^=1; swap(tr[l].ch[0],tr[l].ch[1]); swap(tr[r].ch[0],tr[r].ch[1]); swap(tr[l].mxl,tr[l].mxr); swap(tr[r].mxl,tr[r].mxr); } } void rotate(int x) { int y = tr[x].fa; int z = tr[y].fa; int k = (tr[y].ch[1]==x); tr[tr[x].ch[k^1]].fa = y,tr[y].ch[k] = tr[x].ch[k^1]; tr[x].ch[k^1] = y,tr[y].fa = x; tr[x].fa = z,tr[z].ch[tr[z].ch[1]==y]=x; update(y);update(x); } void splay(int u,int goal) { while(tr[u].fa!=goal) { int y = tr[u].fa; int z = tr[y].fa; if(z!=goal) ((tr[y].ch[1]==u)^(tr[z].ch[1]==y))?rotate(u):rotate(y); rotate(u); } if(!goal)rt=u; } void build(int l,int r,int f) { if(l>r)return ; int mid = (l+r)>>1,u = id[mid],fa = id[f]; if(l==r) { tr[u].siz=1; if(a[l]>0)tr[u].mx=tr[u].mxl=tr[u].mxr=a[l]; else tr[u].mx=a[l],tr[u].mxl=tr[u].mxr=0; }else { build(l,mid-1,mid); build(mid+1,r,mid); } tr[u].vl = a[mid]; tr[u].fa = fa; update(u); tr[fa].ch[mid>=f] = u; } int find(int x,int k) { pushdown(x); int t = tr[tr[x].ch[0]].siz; if(k<=t)return find(tr[x].ch[0],k); else if(k==t+1)return x; else return find(tr[x].ch[1],k-1-t); } void insert(int k,int tt) { for(int i=1;i<=tt;i++) { scanf("%d",&a[i]); } for(int i=1;i<=tt;i++) { if(!que.empty()) { id[i]=que.front(); que.pop(); }else { id[i]=++cnt; } } build(1,tt,0); int rt0 = id[(1+tt)>>1]; int l = find(rt,k+1); int r = find(rt,k+2); splay(l,0); splay(r,l); tr[r].ch[0]=rt0; tr[rt0].fa=r; update(r); update(rt); } void rip(int x) { if(!x)return ; int l = tr[x].ch[0],r = tr[x].ch[1]; rip(l); rip(r); que.push(x); tr[x].xg=tr[x].fz=tr[x].vl=tr[x].sum=0; tr[x].siz=tr[x].mx=tr[x].mxl=tr[x].mxr=tr[x].ch[0]=tr[x].ch[1]=0; } int deal(int l,int r) { l = find(rt,l); r = find(rt,r); splay(l,0); splay(r,l); return tr[r].ch[0]; } void erase(int l,int r) { int x = deal(l,r),y=tr[x].fa; rip(x); tr[y].ch[0]=0; update(y); update(rt); } void make_same(int l,int r,int k) { int x = deal(l,r),y=tr[x].fa; tr[x].vl=k,tr[x].xg=1; tr[x].sum = tr[x].siz*k; if(k>0)tr[x].mx=tr[x].mxl=tr[x].mxr=tr[x].sum; else tr[x].mx=k,tr[x].mxl=tr[x].mxr=0; update(y); update(rt); } void rever(int l,int r) { int x = deal(l,r); if(!tr[x].xg) { tr[x].fz^=1; swap(tr[x].ch[0],tr[x].ch[1]); swap(tr[x].mxl,tr[x].mxr); update(tr[x].fa),update(rt); } } int get_sum(int l,int r) { int x = deal(l,r); return tr[x].sum; } int main() { scanf("%d%d",&n,&m); for(int i=1;i<=n;i++)scanf("%d",&a[i+1]); for(int i=1;i<=n+2;i++)id[i]=i; tr[0].mx=a[1]=a[n+2]=-0x3f3f3f3f; build(1,n+2,0); cnt = n+2; rt = (n+3)>>1; for(int k,i=1;i<=m;i++) { scanf("%s",ch); if(ch[0]=='I') { scanf("%d%d",&pos,&tot); insert(pos,tot); }else if(ch[0]=='D') { scanf("%d%d",&pos,&tot); erase(pos,pos+tot+1); }else if(ch[2]=='K') { scanf("%d%d%d",&pos,&tot,&k); make_same(pos,pos+tot+1,k); }else if(ch[0]=='R') { scanf("%d%d",&pos,&tot); rever(pos,pos+tot+1); }else if(ch[0]=='G') { scanf("%d%d",&pos,&tot); printf("%d ",get_sum(pos,pos+tot+1)); }else { printf("%d ",tr[rt].mx); } } return 0; }