• 【树结构】树链剖分简单分析


    【树结构】树链剖分

    当我们需要在一棵树上完成某些区间操作,而且要求复杂度严格保持在 lg 级别,那么树链剖分往往是不错的选择。

    所谓树链剖分,就是把树分割成链,把每条链放到线段树或其他数据结构里面维护。显然,只要我们保证每次区间操作涉及的链的个数为 O(lgn) ,就可以保证总查询或修改复杂度为O(lg2n)。一种常用的分割方式是 “轻重剖分” ,相关资料网上可以查询到。

    对于一个区间查询如“对a, b最短路径上的所有节点权值求和”,只需要用倍增处理出c = LCA(a, b)转化为对于一个节点和他祖先节点的区间求和。之后只需要不断检查c是否在当前区间内。如果在直接调用数据结构求和,如果不在则求这一区间的总和,并将节点向上推,直到depth[a] < depth[c]。

    以zjoi的树的统计一题为例给出代码:

    #include <bits/stdc++.h>
    using namespace std;
    const int maxm = 100005, maxn = 100005;
    
    int zkw[maxn * 4], sum[maxn * 4], N = 131072, t = 0;
    void update(int i, int k)
    {
        i += N - 1; zkw[i] = sum[i] = k;
        for (i >>= 1; i; i >>= 1) {
            zkw[i] = max(zkw[i << 1], zkw[(i << 1) + 1]);
            sum[i] = sum[i << 1] + sum[(i << 1) + 1];
        }
    }
    
    pair<int, int> query(int i, int j)
    {
        int ans = 0, maxi = INT_MIN, p, q;
        for (p = i + N - 1, q = j + N - 1; p < q; p >>= 1, q >>= 1) {
            if (p & 1) { ans += sum[p], maxi = max(maxi, zkw[p]); p++; }
            if (!(q & 1)) { ans += sum[q], maxi = max(maxi, zkw[q]); q--; }
        }
        if (p == q) ans += sum[p], maxi = max(maxi, zkw[p]);
        return make_pair(ans, maxi);
    }
    
    struct node {
        int to, next;
        node() { to = next = 0; }
    } edge[2 * maxm];
    int head[maxn], top = 0;
    int dat[maxn], siz[maxn], id[maxn], ind[maxn], hev[maxn], dep[maxn];
    int father[maxn][35];
    int n, m;
    
    void push(int i, int j) { edge[++top].to = j; edge[top].next = head[i]; head[i] = top; }
    int dfs1(int i)
    {
        siz[i] = 1;
        for (int k = head[i]; k; k = edge[k].next) {
            if (!dep[edge[k].to]) {
                dep[edge[k].to] = dep[i] + 1; father[edge[k].to][0] = i;
                siz[i] += dfs1(edge[k].to);
            }
        }
        return siz[i];
    }
    void dfs2(int i, int from)
    {
        ind[i] = from; update(++t, dat[i]); id[i] = t;
        if (!head[i]) return;
        hev[i] = 0;
        for (int k = head[i]; k; k = edge[k].next) {
            if (dep[edge[k].to] > dep[i] && siz[edge[k].to] > siz[hev[i]])
                hev[i] = edge[k].to;
        }
        if (!hev[i]) {return;}
        dfs2(hev[i], from);
        for (int k = head[i]; k; k = edge[k].next)
            if (dep[edge[k].to] > dep[i] && edge[k].to != hev[i])
                dfs2(edge[k].to, edge[k].to);
    }
    
    void travel(int, int);
    void init()
    {
        dep[1] = 1;
        memset(father, 0, sizeof father);
        dfs1(1);
        dfs2(1, 1);
        for (int j = 1; j <= 20; j++)
            for (int i = 1; i <= n; i++)
                father[i][j] = father[father[i][j-1]][j-1];
    }
    inline int lowbit(int i) { return i&(-i); }
    int lca(int a, int b)
    {
        if (dep[a] < dep[b]) swap(a, b);
        int dd = dep[a] - dep[b];
        while (dd) { a = father[a][(int)(log2(lowbit(dd)))]; dd -= lowbit(dd); }
        if (a == b) return a;
        for (int i = 20; i >= 0; i--)
            if (father[a][i] != father[b][i])
                a = father[a][i], b = father[b][i];
        return father[a][0];
    }
    
    int query_sum(int i, int j) // j is anc of i
    {
        if (dep[i] < dep[j]) return 0;
        if (dep[ind[i]] <= dep[j])
            return query(id[j], id[i]).first;
        return query_sum(father[ind[i]][0], j) + query(id[ind[i]], id[i]).first;
    }
    
    int query_max(int i, int j)
    {
        if (dep[i] < dep[j]) return INT_MIN;
        if (dep[ind[i]] <= dep[j])
            return query(id[j], id[i]).second;
        return max(query_max(father[ind[i]][0], j), query(id[ind[i]], id[i]).second);
    }
    
    inline void change(int i, int j) { update(id[i], j); }
    inline int read() { int a; scanf("%d", &a); return a; }
    
    int main()
    {
        memset(dep, 0, sizeof dep);
        memset(head, 0, sizeof head);
        memset(hev, 0, sizeof hev);
        memset(sum, 0, sizeof sum);
        memset(zkw, -127/3, sizeof zkw);
        n = read();
        for (int i = 1; i < n; i++) {
            int a, b; a = read(); b = read();
            push(a, b);
            push(b, a);
        }
        for (int i = 1; i <= n; i++)
            dat[i] = read();
        init();
        m = read();
        char str[10]; int a, b, c;
        for (int i = 1; i <= m; i++) {
            scanf("%s", str);
            a = read(); b = read();
            if (strcmp(str, "CHANGE") == 0) change(a, b);
            else if (strcmp(str, "QSUM") == 0) {
                c = lca(a, b);
                printf("%d
    ", query_sum(a, c)+query_sum(b, c)-query_sum(c, c));
            }
            else {
                c = lca(a, b);
                printf("%d
    ", max(query_max(a, c), query_max(b, c)));
            }
        }
        return 0;
    }
    
  • 相关阅读:
    docker使用常用命令:启动/重启/关闭docker
    golang 中内存模型
    【转】Linux fork操作之后发生了什么?又会共享什么呢?
    go检查channel是否关闭
    golang select case 用法
    【转】理解字节序 大端字节序和小端字节序
    【转】3种TCP连接异常的情况。
    react-window 多条列表数据加载(虚拟滚动)
    ts 中 interface 与 class 的区别
    js new一个对象的过程,实现一个简单的new方法
  • 原文地址:https://www.cnblogs.com/ljt12138/p/6684354.html
Copyright © 2020-2023  润新知