首先来介绍一下我们需求:给你n个数,多次问你某个区间内的第k小是哪个数
主席树:
主席树的全名应该是 函数式版本的线段树。加上附带的一堆 technology。。
。。总之由于原名字太长了,而且 “主席” 两个字念起来冷艳高贵,以后全部称之为主席树好了。。
主席树的主体是线段树,准确的说,是很多棵线段树,存的是一段数字区间出现次数(所以要先离散化可能出现的数字)。举个例子,假设我每次都要求整个序列内的第 k 小,那么对整个序列构造一个线段树,然后在线段树上不断找第 k 小在当前数字区间的左半部分还是右半部分。这个操作和平衡树的 Rank 操作一样,只是这里将离散的数字搞成了连续的数字。
先假设没有修改操作:
对于每个前缀 S1…i,保存这样一个线段树 Ti,组成主席树。这样不是会 MLE 么?最后再讲。
注意,这个线段树对一条线段,保存的是这个数字区间的出现次数,所以是可以互相加减的!还有,由于每棵线段树都要保存同样的数字,所以它们的大小、形态也都是一样的!这实在是两个非常好的性质,是平衡树所不具备的。
对于询问 (i,j),我只要拿出 Tj 和 Ti-1,对每个节点相减就可以了。说的通俗一点,询问 i..j 区间中,一个数字区间的出现次数时,就是这些数字在 Tj 中出现的次数减去在 Ti-1 中出现的次数。
由于每棵线段树的大小形态都是一样的,而且初始值全都是 0,那每个线段树都初始化不是太浪费了?所以一开始不用建树。
然后是在某棵树上修改一个数字,由于和其他树相关联,所以不能在原来的树上改,必须弄个新的出来。难道要弄一棵新树?不是的,由于一个数字的更改只影响了一条从这个叶子节点到根的路径,所以只要只有这条路径是新的,另外都没有改变。比如对于某个节点,要往右边走,那么左边那些就不用新建,只要用个指针链到原树的此节点左边就可以了,这个步骤的前提也是线段树的形态一样。
所以根据我的理解以及网上的资料:我们首先使用vector进行离散化,每次加一个点就添一条边,并让(1,x(加入的点离散化后的大小))加一,最后模拟线段树查询就好
#include<set> #include<map> #include<queue> #include<stack> #include<cmath> #include<vector> #include<string> #include<cstdio> #include<cstring> #include<iomanip> #include<stdlib.h> #include<iostream> #include<algorithm> using namespace std; #define eps 1E-8 /*注意可能会有输出-0.000*/ #define Sgn(x) (x<-eps? -1 :x<eps? 0:1)//x为两个浮点数差的比较,注意返回整型 #define Cvs(x) (x > 0.0 ? x+eps : x-eps)//浮点数转化 #define zero(x) (((x)>0?(x):-(x))<eps)//判断是否等于0 #define mul(a,b) (a<<b) #define dir(a,b) (a>>b) typedef long long ll; typedef unsigned long long ull; const int Inf=1<<28; const ll INF=1ll<<60; const double Pi=acos(-1.0); const int Mod=1e9+7; const int Max=1e5+7; int root[Max],tot;//存根节点(n个根节点一定不同) 内存池 struct node { int lef,rig,date; } msegtr[Max*20];//开得足够大 int a[Max]; vector<int> vec; int GetId(int num)//找到离散化后原值对应的值 { return lower_bound(vec.begin(),vec.end(),num)-vec.begin()+1; } void Create(int sta,int enn,int &x,int y,int pos)//建树添边 { msegtr[++tot]=msegtr[y];//更新这条边 msegtr[tot].date++; x=tot;//增加一条边 if(sta==enn) return; int mid=dir(sta+enn,1); if(mid>=pos)//左子树添边 Create(sta,mid,msegtr[x].lef,msegtr[y].lef,pos); else Create(mid+1,enn,msegtr[x].rig,msegtr[y].rig,pos); return; } int Query(int sta,int enn,int x,int y,int k)//查询区间k大,因为满足区间减法 { if(sta==enn) { return sta; } int mid=dir(sta+enn,1); int sum=msegtr[msegtr[y].lef].date-msegtr[msegtr[x].lef].date;//左孩子区间的差 if(sum>=k) return Query(sta,mid,msegtr[x].lef,msegtr[y].lef,k); else return Query(mid+1,enn,msegtr[x].rig,msegtr[y].rig,k-sum); } int main() { int n,m,t; while(~scanf("%d %d",&n,&m)) { root[0]=0; msegtr[0].lef=msegtr[0].rig=msegtr[0].date=0; vec.clear(); tot=0; for(int i=1; i<=n; ++i) { scanf("%d",&a[i]); vec.push_back(a[i]); } sort(vec.begin(),vec.end());//离散化 vec.erase(unique(vec.begin(),vec.end()),vec.end());//去重 for(int i=1; i<=n; ++i) Create(1,n,root[i],root[i-1],GetId(a[i])); int l,r,k; for(int i=0; i<m; ++i) { scanf("%d %d %d",&l,&r,&k); printf("%d ",vec[Query(1,n,root[l-1],root[r],k)-1]);//返回原来的数字 } } return 0; }
参考:http://blog.csdn.net/metalseed/article/details/8045038