前言
线段树(区间树)是什么呢?有了二叉树、二分搜索树,线段树又是干什么的呢?最经典的线段树问题:区间染色;正如它的名字而言,主要解决区间的问题
一、线段树说明
1、什么是线段树?
线段树首先是二叉树,并且是平衡二叉树(它是一 棵空树或它的左右两个子树的高度差的绝对值不超过1,并且左右两个子树都是一棵平衡二叉树),并且具有二分性质。
如下图,就是一颗线段树:
假如,用数组表示线段树,如果区间有n个元素,数组表示需要有多少节点?
2、4n节点推导过程
要进行一下,如果对推导过程不感兴趣的,可以直接记住结论,需要4n个节点,推导过程如下图: PS:依旧是全博客园最丑图,当感觉有进步啊!是不是推荐一下,鼓励一下啊
说明:感觉用尽了洪荒之力,才推导出来了。感觉高考之后再也不会用到等比公式了,但又用到了,还是缘分未尽啊,哈哈哈!最后,都放弃了,一直推导不出来,忘却了最后一层的null,假设是满二叉树,按最大值进行估算,所以4n是完全够大的!
二、为什么要使用线段树
线段树主要解决一些区间问题的,如下:
1、区间染色
有一面墙,长度为n,每次选择一段墙进行染色,m次操作之后,我们可以看见多少种颜色?
2、区间查询
查询区间[i,j]的最大值、最小值,或者区间数字和;实质:基于区间的统计查询。
例如:2018年注册用户中消费最高的用户?消费最低的用户?学习最长时间的用户?
三、代码实现
1、创建线段树
二叉树具有天然递归性质,所以用递归相对简单,用迭代也是可以的,我才用递归实现,代码如下:
template<class T> class SegmentTree { private: T *tree; T *data; int size; std::function<T(T, T)> function; int leftChild(int index) { //左孩子下标;例如用数组存储,根节点是下标0,则左孩子为1,右孩子为2 return index * 2 + 1; } int rightChild(int index) { //右孩子下标 return index * 2 + 2; } void buildSegmentTree(int treeIndex, int l, int r) { if (l == r) { tree[treeIndex] = data[l]; return; } int leftTreeIndex = leftChild(treeIndex); int rightTreeIndex = rightChild(treeIndex); int mid = l + (r - l) / 2; //中间值求法,防止整型溢出 buildSegmentTree(leftTreeIndex, l, mid); //构建左子树 buildSegmentTree(rightTreeIndex, mid + 1, r); //构建右子树 tree[treeIndex] = function(tree[leftTreeIndex], tree[rightTreeIndex]); } public: SegmentTree(T arr[], int n, std::function<T(T, T)> function) { //构造函数,构建一棵树 this->function = function; data = new T[n]; for (int i = 0; i < n; ++i) { data[i] = arr[i]; } tree = new T[n * 4]; //分配4n节点 size = n; buildSegmentTree(0, 0, size - 1); } };
2、线段树查询
线段树具有二分查找性质,所以二分查找那种思路就可以了,代码如下:
T query(int treeIndex, int l, int r, int queryL, int queryR) { if (l == queryL && r == queryR) { return tree[treeIndex]; } int mid = l + (r - l) / 2; int leftTreeIndex = leftChild(treeIndex); int rightTreeIndex = rightChild(treeIndex); if (queryL >= mid + 1) { return query(rightTreeIndex, mid + 1, r, queryL, queryR); } else if (queryR <= mid) { return query(leftTreeIndex, l, mid, queryL, queryR); } T leftResult = query(leftTreeIndex, l, mid, queryL, mid); T rightResult = query(rightTreeIndex, mid + 1, r, mid + 1, queryR); return function(leftResult, rightResult); } T query(int queryL, int queryR) { assert(queryL >= 0 && queryL < size && queryR >= 0 && queryR < size && queryL <= queryR); return query(0, 0, size - 1, queryL, queryR); }
3、整体代码
SegmentTree.h如下:
#ifndef SEGMENT_TREE_SEGMENTTREE_H #define SEGMENT_TREE_SEGMENTTREE_H #include <cassert> #include <functional> template<class T> class SegmentTree { private: T *tree; T *data; int size; std::function<T(T, T)> function; int leftChild(int index) { return index * 2 + 1; } int rightChild(int index) { return index * 2 + 2; } void buildSegmentTree(int treeIndex, int l, int r) { if (l == r) { tree[treeIndex] = data[l]; return; } int leftTreeIndex = leftChild(treeIndex); int rightTreeIndex = rightChild(treeIndex); int mid = l + (r - l) / 2; buildSegmentTree(leftTreeIndex, l, mid); buildSegmentTree(rightTreeIndex, mid + 1, r); tree[treeIndex] = function(tree[leftTreeIndex], tree[rightTreeIndex]); } T query(int treeIndex, int l, int r, int queryL, int queryR) { if (l == queryL && r == queryR) { return tree[treeIndex]; } int mid = l + (r - l) / 2; int leftTreeIndex = leftChild(treeIndex); int rightTreeIndex = rightChild(treeIndex); if (queryL >= mid + 1) { return query(rightTreeIndex, mid + 1, r, queryL, queryR); } else if (queryR <= mid) { return query(leftTreeIndex, l, mid, queryL, queryR); } T leftResult = query(leftTreeIndex, l, mid, queryL, mid); T rightResult = query(rightTreeIndex, mid + 1, r, mid + 1, queryR); return function(leftResult, rightResult); } public: SegmentTree(T arr[], int n, std::function<T(T, T)> function) { this->function = function; data = new T[n]; for (int i = 0; i < n; ++i) { data[i] = arr[i]; } tree = new T[n * 4]; size = n; buildSegmentTree(0, 0, size - 1); } int getSize() { return size; } T get(int index) { assert(index >= 0 && index < size); return data[index]; } T query(int queryL, int queryR) { assert(queryL >= 0 && queryL < size && queryR >= 0 && queryR < size && queryL <= queryR); return query(0, 0, size - 1, queryL, queryR); } void print() { std::cout << "["; for (int i = 0; i < size * 4; ++i) { if (tree[i] != NULL) { std::cout << tree[i]; } else { std::cout << "0"; } if (i != size * 4 - 1) { std::cout << ", "; } } std::cout << "]" << std::endl; } }; #endif //SEGMENT_TREE_SEGMENTTREE_H
main.cpp如下:
#include <iostream> #include "SegmentTree.h" int main() { int nums[] = {-2, 0, 3, -5, 2, -1}; SegmentTree<int> *segmentTree = new SegmentTree<int>(nums, sizeof(nums) / sizeof(int), [](int a, int b) -> int { return a + b; }); std::cout << segmentTree->query(2,5) << std::endl; segmentTree->print(); return 0; }
4、演示
运行结果,如下:
5、时间复杂度分析
更新 O(logn)
查询 O(logn)