二叉查找树
任何一个节点左子树的所有元素都小于这个节点,右子树的所有元素都大于这个节点
查找一个节点:从根节点开始,比他小就向左走,比他大就向右走
平衡树:解决二叉查找树的一些痛点。
二叉查找树的问题:它的形态并不固定,查找非常依赖于深度
通过一种叫做伸展的操作,让树的深度不那么深
那么什么是伸展?
伸展操作基于一个元操作:旋转(rotate)
如果一个节点之前被访问过,那么之后访问到它的几率会变大
通过旋转把这个点移到根,使下一次访问到它只需要o(1)的时间
Splay操作:把一个点旋转到根或者旋转到某个点下面
有x,y,z三个点,z是y的父亲,y是x的父亲
x,y,z三个点,如果在一条直线上就先转y再转x(让树的深度-1)
如果不在一条直线上,就转两次x
fa记录每一个点的父亲,ch记录节点的两个儿子(有可能没有)
函数:son(x)表示x是它父亲的左儿子还是右儿子
Rotate旋转(其实可以说抬升)
Splay将x的节点转到i的位置上
Cnt表示当前节点的数出现了多少次
Data表示当前节点的数是什么
Size表示当前节点及其子树中共有多少个数
pushup:左节点size+右节点size加上自己的cnt
插入:将x插入到rt节点中
如果x<data[rt],就往左子树插
如果x>data[rt],往右子树插
如果x=data[rt],就cnt[rt]++,size[rt]++,直接return掉
如果rt=0,也就是说这个位置没有出现过,就新建一个位置,rt=++nn(总的节点数),data[rt]=x,size[rt]=cnt[rt]=1,return掉
找前驱(小于x最大的数)和后继(大于x最小的数)
Getpre 按照子节点大小关系往两边找前驱
Getaux 同上,找后继
Getmn 找一颗子树中的最小值 只需要不停往左子树跳就可以了
Delete 删除节点
先找这个点
如果在左边,就往左边删除,如果在右边就往右边删除
如果找到了这个点,分情况。如果不只一个数,就cnt[rt]--,size[rt]--就可以了
否则,先把rt转到根,在右子树找最小的元素
如果没有右子树就让根变成他的左儿子
否则就让右儿子最小的节点转到根,因为最小的节点一定没有左儿子
Getk 查询x的排名
现在树中查找这个节点
找到之后,把他转到根,看他左子树有多少个数,+1就是3的排名0
Getkth 找到第k个数
如果左儿子的size+1<=k并且左儿子的size加上当前节点的数>k,就意味着第k个数一定是这个节点
如果k<左儿子的size+1,就往左节点找
否则就往右儿子找。注意要把左儿子的size和自己的cnt减去
int fa[N],ch[N][2]; int cnt[N]; int data[N]; int siz[N]; int son(int x) { return x==ch[fa[x]][1]; } void pushup(int rt) { siz[rt]=siz[ch[rt][0]]+siz[ch[rt][1]]+cnt[rt]; } void rotate(int x){ int y=fa[x],z=fa[y],b=son(x),c=son(y),a=ch[x][!b]; if(z) ch[z][c]=x; else root=x; fa[x]=z; if(a) fa[a]=y; ch[y][b]=a; ch[x][!b]=y;fa[y]=x; pushup(y);pushup(x); } void splay(int x,int i){ while(fa[x]!=i){ int y=fa[x],z=fa[y]; if(z==i){ rotate(x); }else{ if(son(x)==son(y)){ rotate(y);rotate(x); }else{ rotate(x);rotate(x); } } } } void insert(int &rt,int x){ if(rt==0){ rt=++nn; data[rt]=x; siz[rt]=cnt[rt]=1; return; } if(x==data[rt]){ cnt[rt]++; siz[rt]++; return; } if(x<data[rt]){ insert(ch[rt][0],x); fa[ch[rt][0]]=rt; pushup(rt); }else{ insert(ch[rt][1],x); fa[ch[rt][1]]=rt; pushup(rt); } } int getpre(int rt,int x){ int p=rt,ans; while(p){ if(x<=data[p]){ p=ch[p][0]; }else{ ans=p; p=ch[p][1]; } } return ans; } int getsuc(int rt,int x){ int p=rt,ans; while(p){ if(x>=data[p]){ p=ch[p][1]; }else{ ans=p; p=ch[p][0]; } } return ans; } int getmn(int rt){ int p=rt,ans=-1; while(p){ ans=p; p=ch[p][0]; } return ans; } void del(int rt,int x){ if(data[rt]==x){ if(cnt[rt]>1){ cnt[rt]--; siz[rt]--; }else{ splay(rt,0); int p=getmn(ch[rt][1]); if(p==-1){ root=ch[rt][0]; fa[ch[rt][0]]=0; }else{ splay(p,rt); root=p;fa[p]=0; ch[p][0]=ch[rt][0];fa[ch[rt][0]]=p; pushup(p); } } return; } if(x<data[rt]){ del(ch[rt][0],x); }else{ del(ch[rt][1],x); } pushup(rt); } int getk(int rt,int k){ if(data[rt]==k){ splay(rt,0); if(ch[rt][0]==0){ return 1; }else{ return siz[ch[rt][0]]+1; } } if(k<data[rt]) return getk(ch[rt][0],k); if(data[rt]<k) return getk(ch[rt][1],k); } int getkth(int rt,int k){ int l=ch[rt][0]; if(siz[l]+1<=k&&k<=siz[l]+cnt[rt]) return data[rt]; if(k<siz[l]+1) return getkth(l,k); else return getkth(ch[rt][1],k-siz[l]-cnt[rt]); }