良心的可持久化线段树教程
在O~I~中辗转了千~百天,终于可以随手写出各种打标记的、不打标记的、一维的、二维的、求最值的、求和的、求第k大的线段树之后——
我们来学习可持久化线段树吧!
什么是可持久化线段树?
可持久化线段树最大的特点是:可以访问历史版本。例如,我对线段树进行了1000次修改操作,突然问你第233次修改之后某个区间的区间和是多少——这个问题可持久化线段树就可以正常地回答出来。这个性质有许多奇妙的应用。
那么如何实现这样的一棵线段树呢?
想象一棵普通的线段树,我们要对它进行单点修改,需要修改(log n)个点。每次修改的时候,我们丝毫不修改原来的节点,而是在它旁边新建一个节点,把原来节点的信息(如左右儿子编号、区间和等)复制到新节点上,并对新节点进行修改。
那么如何查询历史版本呢?只需记录每一次修改对应的新根节点编号(根据上面描述的操作,根节点每次一定会新建一个的),每次询问从对应的根节点往下查询就好了。
可持久化线段树的代码实现
我们以维护区间和的可持久化线段树为例,下面实现的这棵树支持:单点修改;单点查询。
要定义的数组:
int idx; //index,记录目前一共建过多少节点
int sum[M], lson[M], rson[M]; //区间和、左儿子、右儿子
int root[N]; //每次修改对应的根节点编号
假设这道题一开始序列全是0,首先我们把一棵空的树建出来:
void build(int &k, int l, int r){
//k传的是地址,这样在这一层函数中修改k就可以直接修改上一层的lson或rson了
k = ++idx; //为新节点编号
if(l == r) return; //一定要在创建完新节点之后再return
int mid = (l + r) >> 1;
build(lson[k], l, mid);
build(rson[k], mid + 1, r);
}
接下来实现修改操作,把位置p上的数增加x。
//old是这个位置原来的节点,k是当前要创建的新节点(的地址)
void change(int old, int &k, int l, int r, int p, int x){
k = ++idx; //修改的时候要创建新点
lson[k] = lson[old], rson[k] = rson[old];
sum[k] = sum[old] + x; //先把原来节点的信息复制过来,顺便修改区间和
if(l == r) return; //仍然要记得先建点后return
int mid = (l + r) >> 1;
if(p <= mid) change(lson[k], lson[k], l, mid, p, x);
else change(rson[k], rson[k], mid + 1, r, p, x);
}
下面进行区间和查询(这个函数和普通线段树几乎完全一样)。
int query(int k, int l, int r, int ql, int qr){
if(ql <= l && qr >= r) return sum[k];
int mid = (l + r) >> 1, ans = 0;
if(ql <= mid) ans += query(lson[k], l, mid, ql, qr);
if(qr > mid) ans += query(rson[k], mid + 1, r, ql, qr);
return ans;
}
然后一棵可持久化线段树就完成了!
可持久化线段树注意事项:
- 取地址符
- 先建点,后if(l==r)return
- 空间要开((n + Qlog n))那么大
可持久化线段树的应用:区间第k大
例题:51nod 1175 区间中第k大的数
给出一个序列,每次询问一个区间[l, r]和数字k,问区间中第k大的数是多少。
这个经典的问题还可以用划分树解决——这玩意我还写过,毫无疑问地写跪了。
好在可持久化线段树写这个也非常方便好写,只是需要先离散化处理一下,并且对询问离线。
下面介绍的解法巧妙利用了可持久化线段树支持查找历史版本的特点:
首先把询问按照右端点排序。
建立一棵可持久化线段树,维护每个数出现的次数(若数范围大,需要离散化)。
从左往右扫一遍序列,在可持久化线段树中给对应的数+1。
当处理到某个询问(l, r)的右端点时,发现:对于任意一个数,(右端点加入后的线段树上的值 - 左端点加入前的线段树上的值)就是区间中这个数出现的次数。那么在这棵“减出来的”线段树上进行类似约瑟夫问题的“求第k大数”查询即可。
代码:
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
#define INF 0x3f3f3f3f
#define space putchar(' ')
#define enter putchar('
')
template <class T>
bool read(T &x){
char c;
bool op = 0;
while(c = getchar(), c < '0' || c > '9')
if(c == '-') op = 1;
else if(c == EOF) return 0;
x = c - '0';
while(c = getchar(), c >= '0' && c <= '9')
x = x * 10 + c - '0';
if(op) x = -x;
return 1;
}
template <class T>
void write(T x){
if(x < 0) putchar('-'), x = -x;
if(x >= 10) write(x / 10);
putchar('0' + x % 10);
}
const int N = 50005, M = 2000005;
int n, Q, ans[N], a[N], lst[N], cnt, num[N]; //用于离散化
int idx, sum[M], lson[M], rson[M], root[N];
struct Query{
int id, l, r, x;
bool operator < (const Query &b) const{
return r < b.r;
}
} q[N];
void build(int &k, int l, int r){
k = ++idx;
if(l == r) return;
int mid = (l + r) >> 1;
build(lson[k], l, mid);
build(rson[k], mid + 1, r);
}
void change(int old, int &k, int l, int r, int p, int x){
k = ++idx;
lson[k] = lson[old], rson[k] = rson[old];
sum[k] = sum[old] + x; //先复制一波之前的节点,顺便修改区间和
if(l == r) return;
int mid = (l + r) >> 1;
if(p <= mid) change(lson[k], lson[k], l, mid, p, x);
else change(rson[k], rson[k], mid + 1, r, p, x);
}
int query(int new_k, int old_k, int l, int r, int x){ //查询第x小
if(l == r) return l;
int mid = (l + r) >> 1, sum_right = sum[rson[new_k]] - sum[rson[old_k]];
if(sum_right >= x)
return query(rson[new_k], rson[old_k], mid + 1, r, x);
else
return query(lson[new_k], lson[old_k], l, mid, x - sum_right);
}
int find(int x){
return lower_bound(num + 1, num + cnt + 1, x) - num;
}
int main(){
read(n);
for(int i = 1; i <= n; i++)
read(lst[i]), a[i] = lst[i];
sort(lst + 1, lst + n + 1);
for(int i = 1; i <= n; i++)
if(i == 1 || lst[i] != lst[i - 1])
num[++cnt] = lst[i];
build(root[0], 1, cnt);
read(Q);
for(int i = 1; i <= Q; i++){
q[i].id = i;
read(q[i].l), q[i].l++;
read(q[i].r), q[i].r++;
read(q[i].x);
}
sort(q + 1, q + Q + 1);
for(int i = 1, j = 1; i <= n; i++){
change(root[i - 1], root[i], 1, cnt, find(a[i]), 1);
while(q[j].r == i)
ans[q[j].id] = query(root[i], root[q[j].l - 1], 1, cnt, q[j].x), j++;
}
for(int i = 1; i <= Q; i++)
printf("%d
", num[ans[i]]);
return 0;
}
博主蒟蒻,欢迎指正!