【学习笔记】线段树浅谈
线段树,顾名思义,就是一棵树,上面的每个节点由一个“线段”构成,然后组成了一棵除去了最后一层外,是完全二叉树的树,又叫区间树,作用主要是单点修改,区间修改和区间查询,这里拿线段树的区间和来举例子
线段树的构建
线段树是一棵完全二叉树,我们可以可以按照与二叉堆相似的储存节点的方式,根节点为k,左子树为k<<1,右子树则为k<<1|1,因为树是递归定义的,所以我们开始递归建树。
1.建立结构体,用于存树
因为完全二叉树对应一个长度为a的区间,它需要的节点个数为4a,因此,我们要把树的大小开到区间长度的四倍
const int maxn=1000010;
struct tree{
int l,r;
long long dat,add,mul;
}t[maxn*4+5];
2.建立上传
因为每次建一棵树,父节点都要有相对应的数据更新,因此我们需要pushup来上传子节点的数据
void pushup(int k)
{
t[k].dat=(t[k<<1].dat+t[k<<1|1].dat)%p;
}
3.递归建树
然后用递归建树
//k代表当前节点,l代表当前节点的左端点,r代表当前节点的右端点
void build(int k,int l,int r)
{
//储存一下左端点和右端点,方便以后调用
//同时做一下延迟标记的初始化
t[k].l=l,t[k].r=r;
t[k].add=0,t[k].mul=1;
if(l==r)
{
t[k].dat=a[l];//如果到区间长度为1的时候,那么这个时候就等于对应的数字
return ;
}
int mid=(l+r)>>1;
build(k<<1,l,mid);//递归建左子树
build(k<<1|1,mid+1,r);//递归建右子树
pushup(k);//向父节点输送左子树和右子树的数据
t[k].dat%=p;
return ;
}
这样,一棵线段树就建好了
线段树的操作
对于线段树,我们肯定不仅仅满足于建好一棵线段树,最主要的是,我们要用它做一些什么事情,来帮助我们更好地解决我们需要解决的问题。
前面说过,线段树的操作主要是点的修改和区间的数据查询,我们先说一说点的修改
1.点的修改
因为线段树是用递归定义的,因此,我们点的修改也可以用递归来进行修改这里不过多赘述,因为重要的在后面
2.区间修改
区间修改有两种修改方式,第一种是对应的某一个区间里面的数,每一个加上一个数,另一种是对应区间里面的数,每一个乘上一个数字。
对于加上一个数的方式,我们直接类似于点的修改,但是我们会发现,有一些树是并没有用到的,也就是说,我们每次在更新的时候去更新下一个节点,然后数据继续往下更新,我们就会造成很大的空间与时间的冗余,这些空间和时间都很大,我们都用得到!而且可以节省很大一部分空间!
因此,我们在这里就要使用一种特殊的操作,延迟标记,然后我们就可以节省大量的空间,以及时间,从而加快线段树的速度
大家可以看到,在之前的结构体定义中有add和mul,它们就分别是线段树加法和乘法的延迟标记
然后延迟标记在我们使用某棵树的时候需要传递下去,因此,我们与pushup相类似,我们要使用一个pushdown
void pushdown(int k)
{
//传递节点
t[k<<1].dat=(t[k<<1].dat*t[k].mul+t[k].add*(t[k<<1].r-t[k<<1].l+1))%p;
t[k<<1|1].dat=(t[k<<1|1].dat*t[k].mul+t[k].add*(t[k<<1|1].r-t[k<<1|1].l+1))%p;
//传递乘法的懒惰标记
t[k<<1].mul=(t[k].mul*t[k<<1].mul)%p;
t[k<<1|1].mul=(t[k].mul*t[k<<1|1].mul)%p;
//传递加法的懒惰标记
t[k<<1].add=(t[k<<1].add*t[k].mul+t[k].add)%p;
t[k<<1|1].add=(t[k<<1|1].add*t[k].mul+t[k].add)%p;
//还原父节点,因为子节点已经更新过了
t[k].add=0;
t[k].mul=1;
}
然后我们就可以快乐的进行区间修改了
void updata1(int k,int l,int r,int v)
{
if(l<=t[k].l&&r>=t[k].r)
{
t[k].dat=(t[k].dat*v)%p;
t[k].mul=(t[k].mul*v)%p;
return ;
}
pushdown(k);
int mid=(l+r)>>1;
if(l<=mid)
updata1(k<<1,l,r,v);
if(r>mid)
updata1(k<<1|1,l,r,v);
pushup(k);
}
void updata2(int k,int l,int r,int v)
{
if(l<=t[k].l&&r>=t[k].r)
{
t[k].dat=(t[k].dat+v*(t[k].r-t[k].l+1)%p;
t[k].mul=(t[k].add+v)%p;
return ;
}
pushdown(k);
int mid=(l+r)>>1;
if(l<=mid)
updata2(k<<1,l,r,v);
if(r>mid)
updata2(k<<1|1,l,r,v);
pushup(k);
}
3.区间查询
区间查询和区间修改实际上差不多,但是要注意返回值的时候要
模,不然会出问题
long long query(int k,int l,int r)
{
if(l<=t[k].r&&r>=t[k].r)
return t[k].dat;
pushdown(k);
long long ans=0;
int mid=(l+r)>>1;
if(l<=mid)
ans=(ans+query(k<<1,l,r))%p;
if(r>mid)
ans=(ans+query(k<<1|1,l,r))%p;
return ans%p;
}
这就是线段树的模板了
最终代码
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <string>
using namespace std;
const int maxn=100010;
int n,m,p;
int a[maxn];
struct tree{
int l,r;
long long dat,add,mul;
}t[maxn*4+5];
void pushup(int k)
{
t[k].dat=(t[k<<1].dat+t[k<<1|1].dat)%p;
}
void pushdown(int k)
{
//传递节点
t[k<<1].dat=(t[k<<1].dat*t[k].mul+t[k].add*(t[k<<1].r-t[k<<1].l+1))%p;
t[k<<1|1].dat=(t[k<<1|1].dat*t[k].mul+t[k].add*(t[k<<1|1].r-t[k<<1|1].l+1))%p;
//传递乘法的懒惰标记
t[k<<1].mul=(t[k].mul*t[k<<1].mul)%p;
t[k<<1|1].mul=(t[k].mul*t[k<<1|1].mul)%p;
//传递加法的懒惰标记
t[k<<1].add=(t[k<<1].add*t[k].mul+t[k].add)%p;
t[k<<1|1].add=(t[k<<1|1].add*t[k].mul+t[k].add)%p;
//还原父节点,因为子节点已经更新过了
t[k].add=0;
t[k].mul=1;
}
//k代表当前节点,l代表当前节点的左端点,r代表当前节点的右端点
void build(int k,int l,int r)
{
//储存一下左端点和右端点,方便以后调用
//同时做一下延迟标记的初始化
t[k].l=l,t[k].r=r;
t[k].add=0,t[k].mul=1;
if(l==r)
{
t[k].dat=a[l];//如果到区间长度为1的时候,那么这个时候就等于对应的数字
return ;
}
int mid=(l+r)>>1;
build(k<<1,l,mid);//递归建左子树
build(k<<1|1,mid+1,r);//递归建右子树
pushup(k);//向父节点输送左子树和右子树的数据
t[k].dat%=p;
return ;
}
void updata1(int k,int l,int r,int v)
{
if(l<=t[k].l&&r>=t[k].r)
{
t[k].dat=(t[k].dat*v)%p;
t[k].mul=(t[k].mul*v)%p;
return ;
}
pushdown(k);
int mid=(l+r)>>1;
if(l<=mid)
updata1(k<<1,l,r,v);
if(r>mid)
updata1(k<<1|1,l,r,v);
pushup(k);
}
void updata2(int k,int l,int r,int v)
{
if(l<=t[k].l&&r>=t[k].r)
{
t[k].dat=(t[k].dat+v*(t[k].r-t[k].l+1))%p;
t[k].mul=(t[k].add+v)%p;
return ;
}
pushdown(k);
int mid=(l+r)>>1;
if(l<=mid)
updata2(k<<1,l,r,v);
if(r>mid)
updata2(k<<1|1,l,r,v);
pushup(k);
}
long long query(int k,int l,int r)
{
if(l<=t[k].r&&r>=t[k].r)
return t[k].dat;
pushdown(k);
long long ans=0;
int mid=(l+r)>>1;
if(l<=mid)
ans=(ans+query(k<<1,l,r))%p;
if(r>mid)
ans=(ans+query(k<<1|1,l,r))%p;
return ans%p;
}
int main()
{
}