// Found a fairly good splay-tree template (找到一份比较好的板子); source: https://blog.csdn.net/crazy_ac/article/details/8034190
// Array-based splay tree used as an ordered multiset of ints.
// Node 0 is the null sentinel (sz[0] == cnt[0] == 0); real nodes are 1..top.
#include <cstdio>
#include <cstdlib>

const int inf = ~0u >> 2;  // large sentinel; find_kth() returns -inf for an out-of-range k

// Child shortcuts for a node variable named `x` in the enclosing scope.
#define L ch[x][0]
#define R ch[x][1]
// Left child of the root's right child; not used within this file.
#define KT (ch[ ch[rt][1] ][0])

const int maxn = 500010;  // node capacity; `top` is not bounds-checked against it
int lim;  // NOTE(review): shadowed inside SplayTree by the member `lim`, which del() reads

struct SplayTree {
    int sz[maxn];     // number of stored values (duplicates counted) in x's subtree
    int ch[maxn][2];  // ch[x][0] = left child, ch[x][1] = right child, 0 = none
    int pre[maxn];    // parent of x, 0 for the root
    int cnt[maxn];    // multiplicity of val[x] held by node x
    int val[maxn];    // key held by node x
    int rt, top;      // current root id, last allocated node id
    int lim;          // threshold for del()/update(): values < lim are removed

    // Recompute sz[x] from its own count and its children's sizes.
    inline void up(int x) { sz[x] = cnt[x] + sz[L] + sz[R]; }

    // Rotate x above its parent y.
    // f == 1: x is y's left child (right rotation).
    // f == 0: x is y's right child (left rotation).
    // Only sz[y] is refreshed here; Splay() refreshes x at the end.
    inline void Rotate(int x, int f) {
        int y = pre[x];
        ch[y][!f] = ch[x][f];
        if (ch[x][f]) pre[ch[x][f]] = y;  // don't write into the null sentinel
        pre[x] = pre[y];
        if (pre[x]) ch[pre[y]][ch[pre[y]][1] == y] = x;
        ch[x][f] = y;
        pre[y] = x;
        up(y);
    }

    // Splay x upward until its parent is `goal` (goal == 0 makes x the root).
    inline void Splay(int x, int goal) {
        while (pre[x] != goal) {
            if (pre[pre[x]] == goal) {
                Rotate(x, ch[pre[x]][0] == x);      // single rotation (zig)
            } else {
                int y = pre[x], z = pre[y];
                int f = (ch[z][0] == y);
                if (ch[y][f] == x) {                // x and y on opposite sides: zig-zag
                    Rotate(x, !f);
                    Rotate(x, f);
                } else {                            // same side: zig-zig
                    Rotate(y, f);
                    Rotate(x, f);
                }
            }
        }
        up(x);
        if (goal == 0) rt = x;
    }

    // Splay the node holding the k-th smallest value (1-based, duplicates
    // counted) to just below `goal`. Does nothing if k is out of range.
    inline void RTO(int k, int goal) {
        int x = rt;
        while (x) {
            if (k <= sz[L]) {
                x = L;
            } else if (k > sz[L] + cnt[x]) {
                k -= sz[L] + cnt[x];
                x = R;
            } else {
                Splay(x, goal);
                return;
            }
        }
    }

    // Debug: pre-order dump of the subtree rooted at x, one node per line.
    inline void vist(int x) {
        if (x) {
            printf("结点%2d : 左儿子 %2d 右儿子 %2d val:%2d sz=%d cnt:%d\n",
                   x, L, R, val[x], sz[x], cnt[x]);
            vist(L);
            vist(R);
        }
    }

    // Debug: dump the whole tree.
    void debug() {
        puts("");
        vist(rt);
        puts("");
    }

    // Allocate a fresh leaf with key c and parent f; its id is stored into x.
    inline void Newnode(int &x, int c, int f) {
        x = ++top;
        L = R = 0;
        pre[x] = f;
        sz[x] = 1;
        cnt[x] = 1;
        val[x] = c;
    }

    // Reset to an empty tree (node storage is reused from id 1 onward).
    inline void init() {
        ch[0][0] = ch[0][1] = pre[0] = sz[0] = 0;
        rt = top = 0;
        cnt[0] = 0;
    }

    // Insert one copy of key into the subtree stored in slot x, whose owner is f.
    // x is a reference so a new node can be linked straight into its parent's
    // child slot. The new or updated node is splayed to the root. Splay()
    // already leaves every sz consistent, so the up(x) calls while unwinding
    // are redundant but harmless.
    inline void Insert(int &x, int key, int f) {
        if (!x) {
            Newnode(x, key, f);
            Splay(x, 0);  // splay after inserting
            return;
        }
        if (key == val[x]) {
            cnt[x]++;
            sz[x]++;
            Splay(x, 0);  // splay after inserting
            return;
        } else if (key < val[x]) {
            Insert(L, key, x);
        } else {
            Insert(R, key, x);
        }
        up(x);
    }

    // Remove the root node entirely (all copies of its value).
    void Del_root() {
        int t = rt;
        if (ch[t][1]) {
            // Detach the right subtree so the splay below stops at its root,
            // then bring its minimum up; that node has no left child.
            rt = ch[t][1];
            pre[rt] = 0;
            RTO(1, 0);
            ch[rt][0] = ch[t][0];
            if (ch[rt][0]) pre[ch[rt][0]] = rt;
        } else {
            rt = ch[t][0];
        }
        pre[rt] = 0;
        up(rt);
    }

    // Find the node with the largest val <= key (non-strict) in subtree x.
    // The result is written to ans; ans is left unchanged if no such node exists.
    void findpre(int x, int key, int &ans) {
        if (!x) return;
        if (val[x] <= key) {
            ans = x;
            findpre(R, key, ans);
        } else {
            findpre(L, key, ans);
        }
    }

    // Find the node with the smallest val >= key (non-strict) in subtree x.
    // The result is written to ans; ans is left unchanged if no such node exists.
    void findsucc(int x, int key, int &ans) {
        if (!x) return;
        if (val[x] >= key) {
            ans = x;
            findsucc(L, key, ans);
        } else {
            findsucc(R, key, ans);
        }
    }

    // Return the k-th smallest value (1-based, duplicates counted) in subtree x
    // and splay its node to the root. Returns -inf when k < 1 or k exceeds the size.
    inline int find_kth(int x, int k) {
        if (!x) return -inf;
        if (k <= sz[L]) return find_kth(L, k);
        if (k > sz[L] + cnt[x]) return find_kth(R, k - sz[L] - cnt[x]);
        Splay(x, 0);
        return val[x];
    }

    // Return the node id holding key in subtree x, or 0 if key is absent.
    int find(int x, int key) {
        if (!x) return 0;
        else if (key < val[x]) return find(L, key);
        else if (key > val[x]) return find(R, key);
        else return x;
    }

    // Smallest value in subtree x (returns val[0] for an empty subtree).
    int getmin(int x) {
        while (L) x = L;
        return val[x];
    }

    // Largest value in subtree x (returns val[0] for an empty subtree).
    int getmax(int x) {
        while (R) x = R;
        return val[x];
    }

    // Rank of key within subtree x: 1 + the number of stored values smaller than key.
    // cur is the count of smaller values already known from outside the subtree.
    // If key is absent, this is the rank it would have if inserted.
    int getrank(int x, int key, int cur) {
        while (x) {
            if (key == val[x]) return cur + sz[L] + 1;
            if (key < val[x]) {
                x = L;
            } else {
                cur += sz[L] + cnt[x];
                x = R;
            }
        }
        return cur + 1;
    }

    // Number of stored values strictly less than key in subtree x ("lt" = less than).
    int get_lt(int x, int key) {
        if (!x) return 0;
        if (val[x] >= key) return get_lt(L, key);
        return cnt[x] + sz[L] + get_lt(R, key);
    }

    // Number of stored values strictly greater than key in subtree x ("mt" = more than).
    int get_mt(int x, int key) {
        if (!x) return 0;
        if (val[x] <= key) return get_mt(R, key);
        return cnt[x] + sz[R] + get_mt(L, key);
    }

    // Remove every value < lim from the subtree stored in slot x, whose owner is f.
    // Removed nodes are unlinked, not recycled.
    void del(int &x, int f) {
        if (!x) return;
        if (val[x] >= lim) {
            del(L, x);
        } else {
            // x and its whole left subtree are < lim: replace x by its right child.
            x = R;
            if (x) pre[x] = f;
            if (f == 0) rt = x;
            del(x, f);
        }
        if (x) up(x);
    }

    // Remove every value < lim from the whole tree.
    inline void update() { del(rt, 0); }

    int get_mt(int key) { return get_mt(rt, key); }
    int get_lt(int key) { return get_lt(rt, key); }
    void insert(int key) { Insert(rt, key, 0); }

    // Remove one copy of key; does nothing if key is not present.
    void Delete(int key) {
        int node = find(rt, key);
        if (!node) return;
        Splay(node, 0);
        cnt[rt]--;
        if (cnt[rt]) up(rt);  // keep sz in step with the reduced count
        else Del_root();
    }

    // k-th smallest value overall (1-based); -inf if k is out of range.
    int kth(int k) { return find_kth(rt, k); }
} spt;