转自:http://finaltheory.info/?p=249
可持久化数据结构之主席树
引言
首先引入CLJ论文中的定义:
- 所谓的“持久化数据结构”,就是保存这个数据结构的所有历史版本,同时利用它们之间的共用数据减少时间和空间的消耗。
本文主要讨论两种可持久化线段树的算法思想、具体实现以及编码技巧。
核心思想
-
可持久化线段树是利用函数式编程的思想,对记录的数据只赋值不修改,每次插入一个数据后保存一个历史版本,然后利用线段树的结构完全相同,可以直接相减的特性进行区间询问。
-
我们以经典的区间第K大问题为例:输入n个数字组成的序列,询问序列中区间[l, r]上面的第K大的元素为何。其中第K大定义为:将区间[l, r]按升序排列后的第K个元素。
无修改的区间第K大
-
我们先考虑简化的问题:我们要询问整个区间内的第K大。这样我们对值域建线段树,每个节点记录这个区间所包含的元素个数,建树和查询时的区间范围用递归参数传递,然后用二叉查找树的询问方式即可:即如果左边元素个数sum>=K,递归查找左子树第K大,否则递归查找右子树第K – sum大,直到返回叶子的值。
-
现在我们要回答对于区间[l, r]的第K大询问。如果我们能够得到一个插入原序列中[1, l – 1]元素的线段树,和一颗插入了[1, r]元素的线段树,由于线段树是开在值域上,区间长度是一定的,所以结构也必然是完全相同的,我们可以直接对这两颗线段树进行相减,得到的是相当于插入了区间[l ,r]元素的线段树。注意这里利用到的区间相减性质,实际上是用两颗不同历史版本的线段树进行相减:一颗是插入到第l-1个元素的旧树,一颗是插入到第r元素的新树。
-
这样相减之后得到的是相当于只插入了原序列中[l, r]元素的一颗记录了区间数字个数的线段树。直接对这颗线段树按照BST的方式询问,即可得到区间第k大。
-
这种做法是可行的,但是我们显然不能每次插入一个元素,就从头建立一颗全新的线段树,否则内存开销无法承受。事实上,每次插入一个新的元素时,我们不需要新建所有的节点,而是只新建增加的节点。也就是从根节点出发,先新建节点并复制原节点的值,然后进行修改即可。
-
这样我们我们每到一个节点,只需要修改左儿子或者右儿子其一的信息,一直递归到叶子后结束,修改的节点数量就是树高,也就是新建了不超过树高个节点,内存开销就可以承受了。
-
注意我们对root[0]也就是插入了零个元素的那颗树,记录的左右儿子指针都是0,这样我们就可以用这一个节点表示一个任意结构的空树而不需要显式建树。这是因为对于这个节点,不管你再怎么递归,都是指向这个节点本身,里面记录的元素个数就是零。
有修改的区间第K大
-
当我们要求能够修改元素时,如果还是按照原来的方式保存历史版本,那么修改一个元素后会影响到它后面的所有建好的线段树,这会导致时间开销无法承受。
-
注意观察的话我们会发现,对于root[i]表示的这颗线段树,它保存的是从第一个元素开始插入到第i个元素后的数字区间。也就是说每次我们进行线段树区间相减时,我们是对两个前缀和[1, l – 1]和[1, r]进行了相减。
-
众所周知,我们有一种非常巧妙和简洁的方式来快速维护一个序列的前缀和,那就是树状数组。我们引入树状数组来快速求出线段树序列的前缀和,就能够以增加一个logn复杂度的代价,来让函数式线段树支持单点修改操作。
-
但是在代码的具体实现中存在一个相当重要的细节。如果我们完全按照树状数组维护前缀和的方式,插入原序列中一个一个元素在n颗空树上维护前缀和,那么所需要占用的空间是非常大的,因为增加了很多不必要的节点,而且建树的时间复杂度也增加了一个logn。
-
事实上我们可以先对于初始序列,按照无修改的方式建n颗线段树并且建好后不再修改;然后再对每次修改,在另n颗空树上维护前缀和。两者加在一起之后,就得到了支持修改的插入了原序列[l, r]元素的线段树,对它进行询问即可。
代码实现
部分技巧
-
如上文所述,对于无修改的情况,可以利用空节点可以不断递归的性质来省略显式建树的步骤。
-
对于有修改的情况,在进行区间询问时,传递的就不能仅仅是两个历史版本的线段树,而是用于求出前缀和的所有线段树,这样就得用全局数组来记录这些子树的根节点编号,并且在决定是往左儿子还是右儿子递归时,用各个子树根节点的左儿子/右儿子编号来更新这个数组。
-
如果需要压缩内存占用,可以先读入所有修改,然后将出现过的所有不同数值排序去重之后离散化,这样就可以将线段树开在一个小很多的值域上面,将原序列和修改中的每个元素用二分查找映射过去即可。
无修改(POJ2104/HDU2665)
1 #include <cstdio> 2 #include <cstring> 3 #include <algorithm> 4 #define MAX 100010 5 #define CLR(arr,val) memset(arr,val,sizeof(arr)) 6 using namespace std; 7 const int INF = 0x3f3f3f3f; 8 //记录原数组、排序后的数组、每个元素对应的根节点 9 int nums[MAX], sorted[MAX], root[MAX]; 10 int cnt; 11 struct TMD 12 { 13 int sum, L_son, R_son; 14 } Tree[MAX<<5]; 15 inline int CreateNode( int _sum, int _L_son, int _R_son ) 16 { 17 int idx = ++cnt; 18 Tree[idx].sum = _sum; 19 Tree[idx].L_son = _L_son; 20 Tree[idx].R_son = _R_son; 21 return idx; 22 } 23 void Insert( int & root, int pre_rt, int pos, int L, int R ) 24 { 25 //从根节点往下更新到叶子,新建立出一路更新的节点,这样就是一颗新树了。 26 root = CreateNode( Tree[pre_rt].sum + 1, Tree[pre_rt].L_son, Tree[pre_rt].R_son ); 27 if ( L == R ) return; 28 int M = ( L + R ) >> 1; 29 if ( pos <= M ) 30 Insert( Tree[root].L_son, Tree[pre_rt].L_son, pos, L, M ); 31 else 32 Insert( Tree[root].R_son, Tree[pre_rt].R_son, pos, M + 1, R ); 33 } 34 int Query( int S, int E, int L, int R, int K ) 35 { 36 if ( L == R ) return L; 37 int M = ( L + R ) >> 1; 38 //下面计算的sum就是当前询问的区间中,左儿子中的元素个数。 39 int sum = Tree[Tree[E].L_son].sum - Tree[Tree[S].L_son].sum; 40 if ( K <= sum ) 41 return Query( Tree[S].L_son, Tree[E].L_son, L, M, K ); 42 else 43 return Query( Tree[S].R_son, Tree[E].R_son, M + 1, R, K - sum ); 44 } 45 int main() 46 { 47 int n, m, num, pos, T; 48 while ( scanf("%d %d", &n, &m) != EOF ) 49 { 50 cnt = 0; root[0] = 0; 51 for ( int i = 1; i <= n; ++i ) 52 { 53 scanf("%d", &nums[i]); 54 sorted[i] = nums[i]; 55 } 56 sort( sorted + 1, sorted + 1 + n ); 57 num = unique( sorted + 1, sorted + n + 1 ) - ( sorted + 1 ); 58 for ( int i = 1; i <= n; ++i ) 59 { 60 //实际上是对每个元素建立了一颗线段树,保存其根节点 61 pos = lower_bound( sorted + 1, sorted + num + 1, nums[i] ) - sorted; 62 Insert( root[i], root[i - 1], pos, 1, num ); 63 } 64 int l, r, k; 65 while ( m-- ) 66 { 67 scanf("%d %d %d", &l, &r, &k); 68 pos = Query( root[l - 1], root[r], 1, num, k ); 69 printf("%d ", sorted[pos]); 70 } 71 } 72 }
有修改(ZOJ2112/BZOJ1901)
1 #define CLR(arr,val) memset(arr,val,sizeof(arr)) 2 using namespace std; 3 const int MAX = 50010; 4 const int MAX_q = 10010; 5 const int INF = 0x3f3f3f3f; 6 int nums[MAX], all_val[MAX + MAX_q], root[MAX<<1], prefix_l[100], prefix_r[100]; 7 int cnt, p[2]; 8 struct 9 { 10 int a, b, c; 11 char type; 12 } Querys[MAX_q]; 13 struct TMD 14 { 15 int sum, L_son, R_son; 16 } Tree[MAX*40]; 17 inline int Lowbit( int x ) 18 { 19 return x & (-x); 20 } 21 inline int CreateNode( int _sum, int _L_son, int _R_son ) 22 { 23 int idx = ++cnt; 24 Tree[idx].sum = _sum; 25 Tree[idx].L_son = _L_son; 26 Tree[idx].R_son = _R_son; 27 return idx; 28 } 29 void Build( int & root, int pre_rt, int pos, int L, int R ) 30 { 31 root = CreateNode( Tree[pre_rt].sum + 1, Tree[pre_rt].L_son, Tree[pre_rt].R_son ); 32 if ( L == R ) return; 33 int M = ( L + R ) >> 1; 34 if ( pos <= M ) 35 Build( Tree[root].L_son, Tree[pre_rt].L_son, pos, L, M ); 36 else 37 Build( Tree[root].R_son, Tree[pre_rt].R_son, pos, M + 1, R ); 38 } 39 void Insert( int & root, int pos, int L, int R, int val ) 40 { 41 //如果这颗子树没有被建立,就新建一个节点 42 if ( !root ) 43 root = CreateNode( 0, 0, 0 ); 44 Tree[root].sum += val; 45 if ( L == R ) return; 46 int M = ( L + R ) >> 1; 47 if ( pos <= M ) 48 Insert( Tree[root].L_son, pos, L, M, val ); 49 else 50 Insert( Tree[root].R_son, pos, M + 1, R, val ); 51 } 52 int Query( int L, int R, int K ) 53 { 54 if ( L == R ) return L; 55 int M = ( L + R ) >> 1, sum = 0; 56 //计算前缀和 57 for ( int i = 0; i < p[0]; i++ ) 58 sum += Tree[Tree[prefix_r[i]].L_son].sum; 59 for ( int i = 0; i < p[1]; i++ ) 60 sum -= Tree[Tree[prefix_l[i]].L_son].sum; 61 if ( K <= sum ) { 62 //更新用于计算前缀和的子树根节点编号 63 for ( int i = 0; i < p[0]; i++ ) 64 prefix_r[i] = Tree[prefix_r[i]].L_son; 65 for ( int i = 0; i < p[1]; i++ ) 66 prefix_l[i] = Tree[prefix_l[i]].L_son; 67 return Query( L, M, K ); 68 } else { 69 for ( int i = 0; i < p[0]; i++ ) 70 prefix_r[i] = Tree[prefix_r[i]].R_son; 71 for ( int i = 0; i < p[1]; i++ ) 72 prefix_l[i] = Tree[prefix_l[i]].R_son; 73 return Query( M + 1, R, K - sum ); 74 } 75 } 76 int main() 77 { 78 int n, m, p_val, num; 79 char str[5]; 80 int T; scanf("%d", &T); 81 while ( T-- ) 82 { 83 scanf("%d %d", &n, &m); 84 cnt = 0; p_val = n + 1; 85 for ( int i = 1; i <= n; ++i ) 86 { 87 scanf("%d", &nums[i]); 88 all_val[i] = nums[i]; 89 } 90 //读入所有修改并离散化 91 for ( int i = 0; i < m; ++i ) 92 { 93 scanf("%s %d %d", str, &Querys[i].a, &Querys[i].b); 94 Querys[i].type = str[0]; 95 if ( str[0] == 'Q' ) scanf("%d", &Querys[i].c); 96 else all_val[p_val++] = Querys[i].b; 97 } 98 sort( all_val + 1, all_val + p_val ); 99 num = unique( all_val + 1, all_val + p_val ) - ( all_val + 1 ); 100 //这里直接将初始数字序列映射到离散化后的值域上 101 for ( int i = 1; i <= n; ++i ) 102 nums[i] = lower_bound( all_val + 1, all_val + num + 1, nums[i] ) - all_val; 103 for ( int i = 1; i <= n; ++i ) 104 Build( root[i + n], root[i - 1 + n], nums[i], 1, num ); 105 for ( int i = 0; i < m; ++i ) 106 if ( Querys[i].type == 'Q' ) { 107 p[0] = p[1] = 1; 108 //初始化用于计算前缀和的线段树根节点 109 prefix_r[0] = root[Querys[i].b + n]; 110 prefix_l[0] = root[Querys[i].a - 1 == 0 ? 0 : Querys[i].a - 1 + n]; 111 for ( int arr = Querys[i].b; arr; arr -= Lowbit(arr) ) 112 prefix_r[p[0]++] = root[arr]; 113 for ( int arr = Querys[i].a - 1; arr; arr -= Lowbit(arr) ) 114 prefix_l[p[1]++] = root[arr]; 115 printf("%d ", all_val[Query( 1, num, Querys[i].c )]); 116 } else { 117 for ( int j = Querys[i].a; j <= n; j += Lowbit(j) ) 118 Insert( root[j], nums[Querys[i].a], 1, num, -1 ); 119 //将修改的结果映射到值域,并更新前缀和 120 nums[Querys[i].a] = lower_bound( all_val + 1, all_val + num + 1, Querys[i].b ) - all_val; 121 for ( int j = Querys[i].a; j <= n; j += Lowbit(j) ) 122 Insert( root[j], nums[Querys[i].a], 1, num, 1 ); 123 } 124 CLR( root, 0 ); 125 } 126 }