让我们通过考虑下面的问题来理解线段树。
给定一个数组arr[0 . . . n-1]
,我们要对数组执行这样的操作:
1 计算从下标l到r的元素之和,其中 0 <= l <= r <= n-1
2 修改数组指定元素的值arr[i] = x,其中 0 <= i <= n-1
一个简单的方案是从l
到r
执行循环,计算给定区间的元素之和。更新值的时候,简单地令arr[i] = x。第一个操作花费O(n)的时间,第二个操作花费O(1)的时间。
第二个方案是创建另外一个数组来存储从下标i开始的元素之和。这样一来,给定区间之和可以用O(1)的时间计算,但是更新需要花费O(n)的时间。这种方法适用于需要大量查询而更新操作较少的场景。
如果查询和更新的次数一样多呢?我们可以在O(log n)的时间内完成上述两种操作吗?
我们可以使用线段树来实现在O(log n)时间内完成上述两种操作。
线段树的表示:
1. 叶子节点存储输入的数组元素
2. 每一个内部节点表示某些叶子节点的合并(merge)。合并的方法可能会因问题而异。对于这个问题,合并指的是某个节点之下的所有叶子节点的和。
此处使用树的数组形式来表示线段树。对于下标i
的节点,其左孩子为 2*i+1
,右孩子为2*i+2
,父节点为floor( (i - 1)/2 )
根据给定的数组构建线段树:
我们从线段数组arr[0 . . . n-1]开始。每一次将现有的线段拆分成两半(如果当前线段的长度还不为1),然后在两半线段分别执行同样的过程,并且对于每一个这样的线段,我们存储相应节点的和。
构建出的线段树除最后一层外每一层都会填满。线段树是满二叉树(此处的满二叉树是指树中任意节点的度为0或2的二叉树,与国内计算机教材关于满二叉树的定义不同),因为我们在每一层都将线段拆分为两半。由于构建出的树总是拥有n个叶子节点的满二叉树,因此内部节点会有 n-1 个。因而节点总数为 2*n - 1。
线段树的高度为ceil( log n )
。由于树使用数组表示,并且需要维护父子索引之间的关系,为线段树分配的内存需要:2 * 2 ^ ceil( logn ) - 1
查询给定区间元素之和:
线段树构建完成之后,怎样获取给定区间的和呢?下面的伪代码展示了获取区间元素之和的过程。
int getSum(node, l, r) { if range of node is within l and r return value in node else if range of node is completely outside l and r return 0 else return getSum(node's left child, l, r) + getSum(node's right child, l, r) }
元素值的更新:
与树的构建与查询操作类似,更新操作也可以递归地完成。给定需要更新的数组下标。令diff为需要添加的值。我们从线段树的根节点开始,对所有位于给定区间之内的节点添加diff。如果节点不在区间范围之内,则不做任何修改。
线段树的实现:
下面展示了线段树的实现。程序实现了从任意数组构建线段树,以及查询与更新操作。
// 演示线段树构建、查询、更新等操作的示例程序 #include <stdio.h> #include <math.h> // 获取起止下标中点的工具函数 int getMid(int s, int e) { return s + (e - s)/2; } /* 获取数组给定区间之和的递归函数 下面是函数的参数列表 st --> 线段树的指针 index --> 线段树当前节点的下标。 初始传入根节点的下标为0 根节点的下标值不会变更 ss & se --> 线段树当前节点表示的原数组起止下标 亦即,st[index]的起止下标 qs & qe --> 查询区间的起止下标 */ int getSumUtil(int *st, int ss, int se, int qs, int qe, int index) { // 如果当前节点存储的线段是区间的一部分, // 返回当前线段的和 if (qs <= ss && qe >= se) return st[index]; // 如果节点存储的线段不在给定区间之内 if (se < qs || ss > qe) return 0; // 如果节点的线段与区间的一部分有交集 int mid = getMid(ss, se); return getSumUtil(st, ss, mid, qs, qe, 2*index+1) + getSumUtil(st, mid+1, se, qs, qe, 2*index+2); } /* 更新下标位于给定区间内节点值的递归函数, 下面是参数列表 st, index, ss and se 与getSumUtil() 一致 i --> 待更新元素的下标,指的是输入数组的下标。 diff --> 区间需要增加的值 */ void updateValueUtil(int *st, int ss, int se, int i, int diff, int index) { // Base Case: 如果输入下标在线段树范围之外 if (i < ss || i > se) return; // 如果输入下标在节点范围之内, // 则更新节点及其孩子的值 st[index] = st[index] + diff; if (se != ss) { int mid = getMid(ss, se); updateValueUtil(st, ss, mid, i, diff, 2*index + 1); updateValueUtil(st, mid+1, se, i, diff, 2*index + 2); } } // 更新输入数组与线段树值的函数。 // 使用了函数 updateValueUtil() 来更新线段树的值 void updateValue(int arr[], int *st, int n, int i, int new_val) { // 检查错误的输入下标 if (i < 0 || i > n-1) { printf("Invalid Input"); return; } // 计算新值与老值之间的差值 int diff = new_val - arr[i]; // 更新数组的值 arr[i] = new_val; // 更新线段树节点的值 updateValueUtil(st, 0, n-1, i, diff, 0); } // 返回下标qs(查询起点)到qe(查询终点)的元素之和。 // 主要使用了函数getSumUtil() int getSum(int *st, int n, int qs, int qe) { // 检查错误的输入 if (qs < 0 || qe > n-1 || qs > qe) { printf("Invalid Input"); return -1; } return getSumUtil(st, 0, n-1, qs, qe, 0); } // 递归函数,为数组[ss..se]构建线段树 // si 是线段树st内当前节点的下标 int constructSTUtil(int arr[], int ss, int se, int *st, int si) { // 如果数组只包含一个元素 // 将其存储与线段树的当前节点并返回 if (ss == se) { st[si] = arr[ss]; return arr[ss]; } // 如果有不止一个元素, // 则递归计算左右子树,并将两者之和存储与节点内,并返回 int mid = getMid(ss, se); st[si] = constructSTUtil(arr, ss, mid, st, si*2+1) + constructSTUtil(arr, mid+1, se, st, si*2+2); return st[si]; } /* 从给定数组构建线段树的函数。 函数为线段树分配内存空间,并调用函数constructSTUtil() 来填充分配的内存 */ int *constructST(int arr[], int n) { // 为线段树分配内存空间 int x = (int)(ceil(log2(n))); //线段树的高度 int max_size = 2*(int)pow(2, x) - 1; //线段树的最大容量 int *st = new int[max_size]; // 填充线段树st constructSTUtil(arr, 0, n-1, st, 0); // 返回构建的线段树 return st; } // 上述函数的测试程序 int main() { int arr[] = {1, 3, 5, 7, 9, 11}; int n = sizeof(arr)/sizeof(arr[0]); // 从给定数组构建线段树 int *st = constructST(arr, n); // 输出下标1 到 3的元素之和 printf("Sum of values in given range = %d ", getSum(st, n, 1, 3)); // 更新: 令 arr[1] = 10 // 并更新相应的线段树节点 updateValue(arr, st, n, 1, 10); // 输出更新后的和值 printf("Updated sum of values in given range = %d ", getSum(st, n, 1, 3)); return 0; }
程序输出:
Sum of values in given range = 15 Updated sum of values in given range = 22
时间复杂度:
线段树构建的时间复杂度为O(n)。总计有2n-1个节点,每一个节点在树构建过程中只被运算一次。
查询的时间复杂度为O(log n)。要查询区间和,我们在每一层至多处理4个节点,并且层的总数为O(log n)。
更新的时间复杂度也是O(log n)。要更新一个叶子节点,我们每一层处理一个节点,并且层的总数为O(log n)。
原文链接:http://www.geeksforgeeks.org/segment-tree-set-1-sum-of-given-range/
本文链接:http://bookshadow.com/weblog/2015/08/13/segment-tree-set-1-sum-of-given-range/
较为清晰的代码:
/** * Definition of Interval: * public classs Interval { * int start, end; * Interval(int start, int end) { * this.start = start; * this.end = end; * } */ public class Solution { /* * @param A: An integer list * @param queries: An query list * @return: The result list */ class SegmentTreeNode { public int start; public int end; public long sum; SegmentTreeNode left; SegmentTreeNode right; public SegmentTreeNode(int start, int end, long sum) { this.start = start; this.end = end; this.sum = sum; this.left = null; this.right = null; } } public SegmentTreeNode build(int start, int end, int[] A) { if (start > end) { return null; } if (start == end) { return new SegmentTreeNode(start, end, A[start]); } SegmentTreeNode root = new SegmentTreeNode(start, end, 0); int mid = start + (end - start) / 2; root.left = build(start, mid, A); root.right = build(mid + 1, end, A); if (root.left != null) { root.sum += root.left.sum; } if (root.right != null) { root.sum += root.right.sum; } return root; } public long query(SegmentTreeNode root, int start, int end) { if (start <= root.start && end >= root.end) { return root.sum; } int mid = root.start + (root.end - root.start) / 2; long ans = 0; if (start <= mid) { ans += query(root.left, start, end); } if (end > mid) { ans += query(root.right, start, end); } return ans; } SegmentTreeNode root; public List<Long> intervalSum(int[] A, List<Interval> queries) { root = build(0, A.length - 1, A); List<Long> list = new ArrayList<>(); for (Interval num : queries) { long res = query(root, num.start, num.end); list.add(res); } return list; } }