• luogu P3369 【模板】普通平衡树(splay)


    嘟嘟嘟


    突然觉得splay挺有意思,唯一不足的是这几天是一天一道,debug到崩溃。


    做了几道平衡树基础题后,对这题有莫名的自信,还算愉快的敲完了代码后,发现样例都过不去,然后就陷入了无限的debug环节了……算了,伤心的事就别再提了。


    说一下这题怎么做:
    1.插入
    不说了

    void insert(int x)
    {
      int now = root, f = 0;
      while(now && t[now].val != x) f = now, now = t[now].ch[x > t[now].val];
      if(now) t[now].cnt++;
      else
        {
          now = ++ncnt;
          if(f) t[f].ch[x > t[f].val] = now;
          t[now].fa = f;
          t[now].ch[0] = t[now].ch[1] = 0;	
          t[now].siz = t[now].cnt = 1; t[now].val = x;
        }
      splay(now, 0);
    }
    

    2.删除
    (x)的前驱(a),后继(b),把(a)旋到根,再把(b)旋到(a)的右儿子,这样(b)的左子树就只剩(x)唯一一个节点了。若有多个,(cnt--),否则清零。
    然后一定要把(b)旋转到根。刚开始我很不理解,多亏了学姐说:不旋转怎么更新平衡树啊!我才知道旋转还有一个作用,就是更新这个节点的所有祖先的值。跟线段树回溯的时候更新祖先节点一样。
    我刚开始写了一个(clear)函数,然后因为没传实参debug了半天……

    void del(int x)
    {
      int a = pre(x), b = nxt(x);		
      splay(a, 0); splay(b, a);
      int now = t[b].ch[0];
      if(t[now].cnt > 1) t[now].cnt--, splay(now, 0);
      else t[b].ch[0] = 0;
    }
    

    3.查询(x)数的排名(数据保证(x)存在)
    有一个很棒的做法是查找(x)并把(x)旋到根,然后返回根的左子树的大小就行。
    然而我刚开始是用(bst)的思路写的:如果(x)小于当前节点权值,到左子树找;如果等于,返回累加值(ret)加当前节点大小;否则,(ret)加上左子树和当前节点大小,然后去右子树找。
    这个思路是没问题的,写起来也不难,然而我被某谷的第(12)个点卡掉了:数据是这样的:先添加(50000)个点,接下来(50000)个操作全是查找(i)(50000)的排名。按我的写法是虽然是稳定的(O(n log{n})),但就是TLE;而换了我刚开始说的那个写法后,由于询问可能是(O(1))可能是(O(n log{n}))的,竟然迅速的AC了……
    注释掉的是(bst)的写法

    int queryKth(int x)
    {
      find(x);
      return t[t[root].ch[0]].siz;
      /*int now = root, ret = 0;
      while(1)
        {
          if(t[now].val > x) now = t[now].ch[0];
          else if(t[now].val == x) return ret + t[t[now].ch[0]].siz;
          else ret += t[t[now].ch[0]].siz + t[now].cnt, now = t[now].ch[1];
          }*/
    }
    

    4.查询排名为(x)的数
    (bst)的写法就行:如果(x)小于等于左子树大小,去左子树找;否则如果小于等于左子树加上当前节点的大小,返回当前节点权值;否则(x)减去左子树和节点大小,到右子树去找。

    int queryX(int k)
    {
      int now = root;
      while(1)
        {
          if(k <= t[t[now].ch[0]].siz) now = t[now].ch[0];
          else if(k <= t[t[now].ch[0]].siz + t[now].cnt) return t[now].val;
          else k -= (t[t[now].ch[0]].siz + t[now].cnt), now = t[now].ch[1];
        }
    }
    

    5,6.前驱,后继
    不说了

    void find(int x)
    {
      int now = root;
      if(!now) return;
      while(t[now].val != x && t[now].ch[x > t[now].val]) now = t[now].ch[x > t[now].val];
      splay(now, 0);
    }
    int pre(int x)
    {
      find(x);
      //_PrintTr(root);
      if(t[root].val < x) return root;
      int now = t[root].ch[0];
      while(t[now].ch[1]) now = t[now].ch[1];
      return now;
    }
    int nxt(int x)
    {
      find(x);
      //_PrintTr(root);
      if(t[root].val > x) return root;
      int now = t[root].ch[1];
      while(t[now].ch[0]) now = t[now].ch[0];
      return now;
    }
    

    最后当然要放完整代码啦。(自认为挺短的) ```c++ #include #include #include #include #include #include #include #include #include #include using namespace std; #define enter puts("") #define space putchar(' ') #define Mem(a, x) memset(a, x, sizeof(a)) #define rg register typedef long long ll; typedef double db; const int INF = 0x3f3f3f3f; const db eps = 1e-8; const int maxn = 1e5 + 5; inline ll read() { ll ans = 0; char ch = getchar(), last = ' '; while(!isdigit(ch)) {last = ch; ch = getchar();} while(isdigit(ch)) {ans = (ans << 1) + (ans << 3) + ch - '0'; ch = getchar();} if(last == '-') ans = -ans; return ans; } inline void write(ll x) { if(x < 0) x = -x, putchar('-'); if(x >= 10) write(x / 10); putchar(x % 10 + '0'); }

    int n;
    struct Tree
    {
    int ch[2], fa;
    int siz, cnt, val;
    }t[maxn];
    int root, ncnt = 0;

    void _PrintTr(int now)
    {
    if(!now) return;
    printf("nd:%d val:%d ls:%d rs:%d ", now, t[now].val, t[t[now].ch[0]].val, t[t[now].ch[1]].val);
    _PrintTr(t[now].ch[0]); _PrintTr(t[now].ch[1]);
    }

    void pushup(int now)
    {
    t[now].siz = t[t[now].ch[0]].siz + t[t[now].ch[1]].siz + t[now].cnt;
    }
    void rotate(int x)
    {
    int y = t[x].fa, z = t[y].fa, k = (t[y].ch[1] == x);
    t[z].ch[t[z].ch[1] == y] = x; t[x].fa = z;
    t[y].ch[k] = t[x].ch[k ^ 1]; t[t[x].ch[k ^ 1]].fa = y;
    t[x].ch[k ^ 1] = y; t[y].fa = x;
    pushup(y); pushup(x);
    }
    void splay(int x, int s)
    {
    while(t[x].fa != s)
    {
    int y = t[x].fa, z = t[y].fa;
    if(z != s)
    {
    if((t[z].ch[1] == y) ^ (t[y].ch[1] == x)) rotate(x);
    else rotate(y);
    }
    rotate(x);
    }
    if(!s) root = x;
    }
    void insert(int x)
    {
    int now = root, f = 0;
    while(now && t[now].val != x) f = now, now = t[now].ch[x > t[now].val];
    if(now) t[now].cnt++;
    else
    {
    now = ++ncnt;
    if(f) t[f].ch[x > t[f].val] = now;
    t[now].fa = f;
    t[now].ch[0] = t[now].ch[1] = 0;
    t[now].siz = t[now].cnt = 1; t[now].val = x;
    }
    splay(now, 0);
    }
    void find(int x)
    {
    int now = root;
    if(!now) return;
    while(t[now].val != x && t[now].ch[x > t[now].val]) now = t[now].ch[x > t[now].val];
    splay(now, 0);
    }
    int pre(int x)
    {
    find(x);
    //_PrintTr(root);
    if(t[root].val < x) return root;
    int now = t[root].ch[0];
    while(t[now].ch[1]) now = t[now].ch[1];
    return now;
    }
    int nxt(int x)
    {
    find(x);
    //_PrintTr(root);
    if(t[root].val > x) return root;
    int now = t[root].ch[1];
    while(t[now].ch[0]) now = t[now].ch[0];
    return now;
    }
    void del(int x)
    {
    int a = pre(x), b = nxt(x);
    splay(a, 0); splay(b, a);
    int now = t[b].ch[0];
    if(t[now].cnt > 1) t[now].cnt--, splay(now, 0);
    else t[b].ch[0] = 0;
    }
    int queryKth(int x)
    {
    find(x);
    return t[t[root].ch[0]].siz;
    /int now = root, ret = 0;
    while(1)
    {
    if(t[now].val > x) now = t[now].ch[0];
    else if(t[now].val == x) return ret + t[t[now].ch[0]].siz;
    else ret += t[t[now].ch[0]].siz + t[now].cnt, now = t[now].ch[1];
    }
    /
    }
    int queryX(int k)
    {
    int now = root;
    while(1)
    {
    if(k <= t[t[now].ch[0]].siz) now = t[now].ch[0];
    else if(k <= t[t[now].ch[0]].siz + t[now].cnt) return t[now].val;
    else k -= (t[t[now].ch[0]].siz + t[now].cnt), now = t[now].ch[1];
    }
    }

    int main()
    {
    //freopen("test.in", "r", stdin);
    //freopen("ha.out", "w", stdout);
    insert(-INF); insert(INF);
    n = read();
    for(int i = 1; i <= n; ++i)
    {
    int op = read(), x = read();
    if(op == 1) insert(x);
    else if(op == 2) del(x);
    else if(op == 3) write(queryKth(x)), enter;
    else if(op == 4) write(queryX(x + 1)), enter;
    else if(op == 5) write(t[pre(x)].val), enter;
    else write(t[nxt(x)].val), enter;
    //_PrintTr(root);
    }
    return 0;
    }

  • 相关阅读:
    JS对文本框值的判断
    PostgreSQL导出表中数据
    postgreSQL中跨库查询在windows下的实现方法
    获取表中每个编号最新一条数据的集合
    开源跨平台声波传输库:Sonic
    Windows下配置cygwin和ndk编译环境
    char的定义在iOS和Android下是不同的
    stdout引发的curl 302跳转 crash
    WebBrowser内嵌页面的跨域调用问题
    【已解决】Ubuntu 12.04 LTS Source安装nodejs时出现"bash ./configure permission denied"
  • 原文地址:https://www.cnblogs.com/mrclr/p/10056927.html
Copyright © 2020-2023  润新知