• 线段树 | 第1讲 (给定区间求和)(转)


    让我们通过考虑下面的问题来理解线段树。

    给定一个数组arr[0 . . . n-1],我们要对数组执行这样的操作:

    1 计算从下标l到r的元素之和,其中 0 <= l <= r <= n-1
    ​2 修改数组指定元素的值arr[i] = x,其中 0 <= i <= n-1

    一个简单的方案是从lr执行循环,计算给定区间的元素之和。更新值的时候,简单地令arr[i] = x。第一个操作花费O(n)的时间,第二个操作花费O(1)的时间。

    第二个方案是创建另外一个数组来存储从下标i开始的元素之和。这样一来,给定区间之和可以用O(1)的时间计算,但是更新需要花费O(n)的时间。这种方法适用于需要大量查询而更新操作较少的场景。

    如果查询和更新的次数一样多呢?我们可以在O(log n)的时间内完成上述两种操作吗?

    我们可以使用线段树来实现在O(log n)时间内完成上述两种操作。

    线段树的表示:

    1. 叶子节点存储输入的数组元素
    2. 每一个内部节点表示某些叶子节点的合并(merge)。合并的方法可能会因问题而异。对于这个问题,合并指的是某个节点之下的所有叶子节点的和。

    此处使用树的数组形式来表示线段树。对于下标i的节点,其左孩子为 2*i+1,右孩子为2*i+2,父节点为floor( (i - 1)/2 )

    segment-tree

    根据给定的数组构建线段树:

    我们从线段数组arr[0 . . . n-1]开始。每一次将现有的线段拆分成两半(如果当前线段的长度还不为1),然后在两半线段分别执行同样的过程,并且对于每一个这样的线段,我们存储相应节点的和。

    构建出的线段树除最后一层外每一层都会填满。线段树是满二叉树(此处的满二叉树是指树中任意节点的度为0或2的二叉树,与国内计算机教材关于满二叉树的定义不同),因为我们在每一层都将线段拆分为两半。由于构建出的树总是拥有n个叶子节点的满二叉树,因此内部节点会有 n-1 个。因而节点总数为 2*n - 1。

    线段树的高度为ceil( log n )。由于树使用数组表示,并且需要维护父子索引之间的关系,为线段树分配的内存需要:2 * 2 ^ ceil( logn ) - 1

    查询给定区间元素之和:

    线段树构建完成之后,怎样获取给定区间的和呢?下面的伪代码展示了获取区间元素之和的过程。

    int getSum(node, l, r) 
    {
       if range of node is within l and r
            return value in node
       else if range of node is completely outside l and r
            return 0
       else
        return getSum(node's left child, l, r) + 
               getSum(node's right child, l, r)
    }

    元素值的更新:

    与树的构建与查询操作类似,更新操作也可以递归地完成。给定需要更新的数组下标。令diff为需要添加的值。我们从线段树的根节点开始,对所有位于给定区间之内的节点添加diff。如果节点不在区间范围之内,则不做任何修改。

    线段树的实现:

    下面展示了线段树的实现。程序实现了从任意数组构建线段树,以及查询与更新操作。

    // 演示线段树构建、查询、更新等操作的示例程序
    #include <stdio.h>
    #include <math.h>
     
    // 获取起止下标中点的工具函数
    int getMid(int s, int e) {  return s + (e - s)/2;  }
     
    /*  获取数组给定区间之和的递归函数
        下面是函数的参数列表
     
        st    --> 线段树的指针
        index --> 线段树当前节点的下标。 初始传入根节点的下标为0
                 根节点的下标值不会变更
        ss & se  --> 线段树当前节点表示的原数组起止下标
                     亦即,st[index]的起止下标
        qs & qe  --> 查询区间的起止下标 */
    int getSumUtil(int *st, int ss, int se, int qs, int qe, int index)
    {
        // 如果当前节点存储的线段是区间的一部分,
        // 返回当前线段的和
        if (qs <= ss && qe >= se)
            return st[index];
     
        // 如果节点存储的线段不在给定区间之内
        if (se < qs || ss > qe)
            return 0;
     
        // 如果节点的线段与区间的一部分有交集
        int mid = getMid(ss, se);
        return getSumUtil(st, ss, mid, qs, qe, 2*index+1) +
               getSumUtil(st, mid+1, se, qs, qe, 2*index+2);
    }
     
    /* 更新下标位于给定区间内节点值的递归函数,
        下面是参数列表
        st, index, ss and se 与getSumUtil() 一致
        i    --> 待更新元素的下标,指的是输入数组的下标。
       diff --> 区间需要增加的值 */
    void updateValueUtil(int *st, int ss, int se, int i, int diff, int index)
    {
        // Base Case: 如果输入下标在线段树范围之外
        if (i < ss || i > se)
            return;
     
        // 如果输入下标在节点范围之内,
        // 则更新节点及其孩子的值
        st[index] = st[index] + diff;
        if (se != ss)
        {
            int mid = getMid(ss, se);
            updateValueUtil(st, ss, mid, i, diff, 2*index + 1);
            updateValueUtil(st, mid+1, se, i, diff, 2*index + 2);
        }
    }
     
    // 更新输入数组与线段树值的函数。
    // 使用了函数 updateValueUtil() 来更新线段树的值
    void updateValue(int arr[], int *st, int n, int i, int new_val)
    {
        // 检查错误的输入下标
        if (i < 0 || i > n-1)
        {
            printf("Invalid Input");
            return;
        }
     
        // 计算新值与老值之间的差值
        int diff = new_val - arr[i];
     
        // 更新数组的值
        arr[i] = new_val;
     
        // 更新线段树节点的值
        updateValueUtil(st, 0, n-1, i, diff, 0);
    }
     
    // 返回下标qs(查询起点)到qe(查询终点)的元素之和。
    // 主要使用了函数getSumUtil()
    int getSum(int *st, int n, int qs, int qe)
    {
        // 检查错误的输入
        if (qs < 0 || qe > n-1 || qs > qe)
        {
            printf("Invalid Input");
            return -1;
        }
     
        return getSumUtil(st, 0, n-1, qs, qe, 0);
    }
     
    // 递归函数,为数组[ss..se]构建线段树
    // si 是线段树st内当前节点的下标
    int constructSTUtil(int arr[], int ss, int se, int *st, int si)
    {
        // 如果数组只包含一个元素
        // 将其存储与线段树的当前节点并返回
        if (ss == se)
        {
            st[si] = arr[ss];
            return arr[ss];
        }
     
        // 如果有不止一个元素,
        // 则递归计算左右子树,并将两者之和存储与节点内,并返回
        int mid = getMid(ss, se);
        st[si] =  constructSTUtil(arr, ss, mid, st, si*2+1) +
                  constructSTUtil(arr, mid+1, se, st, si*2+2);
        return st[si];
    }
     
    /* 从给定数组构建线段树的函数。
       函数为线段树分配内存空间,并调用函数constructSTUtil()
       来填充分配的内存 */
    int *constructST(int arr[], int n)
    {
        // 为线段树分配内存空间
        int x = (int)(ceil(log2(n))); //线段树的高度
        int max_size = 2*(int)pow(2, x) - 1; //线段树的最大容量
        int *st = new int[max_size];
     
        // 填充线段树st
        constructSTUtil(arr, 0, n-1, st, 0);
     
        // 返回构建的线段树
        return st;
    }
     
    // 上述函数的测试程序
    int main()
    {
        int arr[] = {1, 3, 5, 7, 9, 11};
        int n = sizeof(arr)/sizeof(arr[0]);
     
        // 从给定数组构建线段树
        int *st = constructST(arr, n);
     
        // 输出下标1 到 3的元素之和
        printf("Sum of values in given range = %d
    ", getSum(st, n, 1, 3));
     
        // 更新: 令 arr[1] = 10
        //  并更新相应的线段树节点
        updateValue(arr, st, n, 1, 10);
     
        // 输出更新后的和值
        printf("Updated sum of values in given range = %d
    ",
                                                      getSum(st, n, 1, 3));
     
        return 0;
    }

    程序输出:

    Sum of values in given range = 15
    Updated sum of values in given range = 22

    时间复杂度:

    线段树构建的时间复杂度为O(n)。总计有2n-1个节点,每一个节点在树构建过程中只被运算一次。

    查询的时间复杂度为O(log n)。要查询区间和,我们在每一层至多处理4个节点,并且层的总数为O(log n)。

    更新的时间复杂度也是O(log n)。要更新一个叶子节点,我们每一层处理一个节点,并且层的总数为O(log n)。

    原文链接:http://www.geeksforgeeks.org/segment-tree-set-1-sum-of-given-range/

    本文链接:http://bookshadow.com/weblog/2015/08/13/segment-tree-set-1-sum-of-given-range/

    较为清晰的代码:

    /**
     * Definition of Interval:
     * public classs Interval {
     *     int start, end;
     *     Interval(int start, int end) {
     *         this.start = start;
     *         this.end = end;
     *     }
     */
    
    
    public class Solution {
        /*
         * @param A: An integer list
         * @param queries: An query list
         * @return: The result list
         */
        class SegmentTreeNode {
            public int start;
            public int end;
            public long sum;
            SegmentTreeNode left;
            SegmentTreeNode right;
            public SegmentTreeNode(int start, int end, long sum) {
                this.start = start;
                this.end = end;
                this.sum = sum;
                this.left = null;
                this.right = null;
            }
        }
    
        public SegmentTreeNode build(int start, int end, int[] A) {
            if (start > end) {
                return null;
            }
        
            if (start == end) {
                return new SegmentTreeNode(start, end, A[start]);
            }
        
            SegmentTreeNode root = new SegmentTreeNode(start, end, 0);
            int mid = start + (end - start) / 2;
            root.left = build(start, mid, A);
            root.right = build(mid + 1, end, A);
            if (root.left != null) {
                root.sum += root.left.sum;
            }
            if (root.right != null) {
                root.sum += root.right.sum;
            }
            return root;
        }
    
        public long query(SegmentTreeNode root, int start, int end) {
            if (start <= root.start && end >= root.end) {
                return root.sum;
            }
        
            int mid = root.start + (root.end - root.start) / 2;
            long ans = 0;
            if (start <= mid) {
                ans += query(root.left, start, end);
            } 
            if (end > mid) {
                ans += query(root.right, start, end);
            }
        
            return ans;
        }
     
        SegmentTreeNode root;
        public List<Long> intervalSum(int[] A, List<Interval> queries) {
            root = build(0, A.length - 1, A);
            List<Long> list = new ArrayList<>();
            
            for (Interval num : queries) {
                long res = query(root, num.start, num.end);
                list.add(res);
            }
            
            return list;
        }
    }
    联系方式:emhhbmdfbGlhbmcxOTkxQDEyNi5jb20=
  • 相关阅读:
    Spring Cloud 入门教程6、Hystrix Dashboard监控数据聚合(Turbine)
    Spring Cloud 入门教程5、服务容错监控:Hystrix Dashboard
    Spring Cloud 入门教程4、服务容错保护:断路器(Hystrix)
    Spring Cloud 入门教程3、服务消费者(Feign)
    JS==和===总结
    怀疑与批判
    Java多线程速记手册
    编译原理
    C系、Java、JavaScript、C#、PHP、Swift基本语法对比
    单例模式番外篇
  • 原文地址:https://www.cnblogs.com/zl1991/p/13229435.html
Copyright © 2020-2023  润新知