0. 简介
树链剖分用于处理树上问题(废话)
思想是把一棵树分成若干条链,而链上的操作显然比树上好做,这样就降低了处理难度
有多种方式能将树拆成链,因此树链剖分也有不同的种类
这篇文章写的是其中一种:轻重链剖分(最常见,有时也被直接称为树链剖分)
1. 概念
定义 \(size_x\) 表示以 \(x\) 为根的子树大小(即包含的节点个数)
称 \(x\) 的所有儿子中 \(size\) 值最大的那个为 \(x\) 的重儿子,记作 \(son_x\)
(如果有多个儿子满足条件,随便选其中一个即可)
对于每个节点 \(x\) ,将连接 \(x\) 和 \(son_x\) 的边定为重边,其它的定为轻边
将全部由重边组成的路径称为重链
这样树就被分成了若干条重链和若干条轻边
分完之后我们发现了一些有用的性质:
-
若 \(y\) 是 \(x\) 的儿子,但不是重儿子,则 \(size_y\le size_x/2.\)
反证即可,假设 \(size_y>size_x/2\) ,则 \(size_y\) 比其它儿子的 \(size\) 值之和还大,故 \(y\) 是重儿子,与条件矛盾
所以 \(size_y\le size_x/2\) -
从任意一个节点到根节点的路径上最多有 \(O(\log n)\) 条轻边。
由性质 1 可知,通过一条轻边往下走,子树大小至少减半,故显然性质成立 -
从任意一个节点到根节点的路径上最多有 \(O(\log n)\) 条重链。
重链之间是用轻边分隔的(废话)
因此重链的条数和轻边一样也是 \(O(\log n)\) 级别
题目常会让我们对两点间路径上的所有点执行某操作
由上面的性质易知,这条路径可被分成不超过 \(O(\log n)\) 条重链和轻边
我们希望能快速处理重链上的操作
对整棵树跑一遍深度优先遍历,并且优先遍历重儿子
则同一条重链对应的 dfs 序必然是连续的一段
然后重链上的操作就转化成了序列问题,用合适的数据结构维护即可
下面通过一道模板题讲一下具体怎么使用
2. 实现
首先预处理出必要的信息
/**
* fa[x]: 节点 x 的爹
* dep[x]: 节点 x 的深度
* sz[x]: 节点 x 的子树大小,即上文中的 size
* son[x]: 节点 x 的重儿子
* top[x]: 节点 x 所在重链的顶端的节点(深度最小)
* ord[x]: 节点 x 的 dfs 序
* rev[n]: dfs 序为 n 的节点,即 rev 是 ord 的逆映射
**/
void dfs(int x,int f) { // 第一遍 dfs ,处理 fa,dep,sz,son
fa[x]=f,dep[x]=dep[f]+1,sz[x]=1;
for (int i=last[x]; i; i=E[i].pre) {
int y=E[i].y;
if (y==f) continue;
dfs(y,x);
sz[x]+=sz[y];
if (sz[son[x]]<sz[y]) son[x]=y;
}
}
void dfs2(int x); // 第二遍 dfs ,处理 top,ord,rev
void work(int x,int topx) {
top[x]=topx;
++n2,ord[x]=n2,rev[n2]=x;
dfs2(x);
}
void dfs2(int x) {
if (!son[x]) return ;
work(son[x],top[x]); // 优先遍历重儿子
for (int i=last[x]; i; i=E[i].pre)
if (!top[E[i].y]) work(E[i].y,E[i].y);
}
操作 1,2 是关于两节点 \(x,y\) 之间路径的
先给出操作 1 的代码实现:
void solve1(int x,int y,int v) {
int tx=top[x],ty=top[y];
while (tx!=ty) {
if (dep[tx]<dep[ty]) swap(x,y),swap(tx,ty);
update(ord[tx],ord[x],v,1,n,1); // 修改
x=fa[tx],tx=top[x]; // 跳到下一条重链底端
}
if (dep[x]>dep[y]) swap(x,y);
update(ord[x],ord[y],v,1,n,1);
}
其实就是一直让深度较大的那个点往上跳
每次跳过一整条重链,并修改这条链上的信息
两点位于同一条重链上时停止,此时剩下的一小段路径都在这条链上,直接修改即可
操作 2 也没啥区别,把修改换成查询,最后返回总和即可
ll solve2(int x,int y) {
int tx=top[x],ty=top[y]; ll ans=0;
while (tx!=ty) {
if (dep[tx]<dep[ty]) swap(x,y),swap(tx,ty);
ans=(ans+query(ord[tx],ord[x],1,n,1))%P;
x=fa[tx],tx=top[x];
}
if (dep[x]>dep[y]) swap(x,y);
return (ans+query(ord[x],ord[y],1,n,1))%P;
}
操作 3,4 和子树相关,所以一次操作涉及的所有节点的 dfs 序是连续的一段
直接在对应区间上做修改/查询即可
void solve3(int x,int v) {
update(ord[x],ord[x]+sz[x]-1,v,1,n,1);
}
ll solve4(int x) {
return query(ord[x],ord[x]+sz[x]-1,1,n,1);
}
其中 update
和 query
的写法取决于你用什么数据结构
顺便讲一下树状数组怎么做区间修改和区间查询
对于区间修改,差分一波
设 \(c_i=a_i-a_{i-1}\) (规定 \(a_0=0\) )
若要将 \(a_l\sim a_r\) 全部加 \(x\) ,则 \(\Delta c_l=x,\Delta c_{r+1}=-x\)
对于区间查询,只需考虑如何求前缀和 \(S(x)=\sum\limits_{i=1}^x a_i\)
\(\begin{aligned} S(x)&=\sum\limits_{i=1}^x a_i\\ &=\sum\limits_{i=1}^x\sum\limits_{j=1}^ic_j\\ &=\sum\limits_{j=1}^x(x+1-j)c_j\\ &=(x+1)\sum\limits_{i=1}^xc_i-\sum\limits_{i=1}^xic_i \end{aligned}\)
设 \(d_i=ic_i\) ,则 \(S(x)\) 可以用 \(c\) 和 \(d\) 的前缀和表示
用树状数组维护 \(c\) 和 \(d\) 即可
因此本题中树状数组和线段树均可使用,时间复杂度都是 \(O(n+m\log ^2n)\) (跑不满)
3. 完整代码
线段树(使用懒标记)版本
线段树(标记永久化)和树状数组的写得太丑了 也懒得重写 就不放了(
P3384 SGT ver.
#include<stdio.h>
#include<ctype.h>
#define Tl T[p<<1]
#define Tr T[p<<1|1]
typedef long long ll;
const int N=100010;
int n,m,root,P,n2,cnt,opt,x,y,v,value[N],last[N];
int fa[N],dep[N],sz[N],son[N],top[N],ord[N],rev[N];
struct edge { int y,pre; }E[N<<1];
inline void swap(int &x,int &y) { int t=x; x=y,y=t; }
void read(int &x);
struct SGT { int len; ll tag,sum; }T[N<<2];
void build(int l,int r,int p);
void pushup(int p);
void pushdown(int p);
void update(int x,int y,int v,int l,int r,int p);
ll query(int x,int y,int l,int r,int p);
void dfs(int x,int f) {
fa[x]=f,dep[x]=dep[f]+1,sz[x]=1;
for (int i=last[x]; i; i=E[i].pre) {
int y=E[i].y;
if (y==f) continue;
dfs(y,x);
sz[x]+=sz[y];
if (sz[son[x]]<sz[y]) son[x]=y;
}
}
void dfs2(int x);
void work(int x,int topx) {
top[x]=topx;
++n2,ord[x]=n2,rev[n2]=x;
dfs2(x);
}
void dfs2(int x) {
if (!son[x]) return ;
work(son[x],top[x]);
for (int i=last[x]; i; i=E[i].pre)
if (!top[E[i].y]) work(E[i].y,E[i].y);
}
void solve1(int x,int y,int v) {
int tx=top[x],ty=top[y];
while (tx!=ty) {
if (dep[tx]<dep[ty]) swap(x,y),swap(tx,ty);
update(ord[tx],ord[x],v,1,n,1);
x=fa[tx],tx=top[x];
}
if (dep[x]>dep[y]) swap(x,y);
update(ord[x],ord[y],v,1,n,1);
}
ll solve2(int x,int y) {
int tx=top[x],ty=top[y]; ll ans=0;
while (tx!=ty) {
if (dep[tx]<dep[ty]) swap(x,y),swap(tx,ty);
ans=(ans+query(ord[tx],ord[x],1,n,1))%P;
x=fa[tx],tx=top[x];
}
if (dep[x]>dep[y]) swap(x,y);
return (ans+query(ord[x],ord[y],1,n,1))%P;
}
void solve3(int x,int v) {
update(ord[x],ord[x]+sz[x]-1,v,1,n,1);
}
ll solve4(int x) {
return query(ord[x],ord[x]+sz[x]-1,1,n,1);
}
int main() {
read(n),read(m),read(root),read(P);
for (int i=1; i<=n; ++i) read(value[i]);
for (int i=1; i<n; ++i) {
read(x),read(y);
E[++cnt]={y,last[x]},last[x]=cnt;
E[++cnt]={x,last[y]},last[y]=cnt;
}
dfs(root,0);
work(root,root);
build(1,n,1);
while (m--) {
read(opt),read(x);
if (opt==1) read(y),read(v),solve1(x,y,v);
if (opt==2) read(y),printf("%lld\n",solve2(x,y));
if (opt==3) read(v),solve3(x,v);
if (opt==4) printf("%lld\n",solve4(x));
}
return 0;
}
void read(int &x) {
x=0; char ch=getchar();
while (!isdigit(ch)) ch=getchar();
while (isdigit(ch)) x=x*10+(ch^48),ch=getchar();
}
void build(int l,int r,int p) {
T[p].len=r-l+1;
if (l==r) {
T[p].sum=value[rev[l]];
return ;
}
int mid=(l+r>>1);
build(l,mid,p<<1);
build(mid+1,r,p<<1|1);
pushup(p);
}
inline void pushup(int p) {
T[p].sum=(Tl.sum+Tr.sum)%P;
}
inline void pushdown(int p) {
int t=T[p].tag;
Tl.tag=(Tl.tag+t)%P;
Tl.sum=(Tl.sum+t*Tl.len)%P;
Tr.tag=(Tr.tag+t)%P;
Tr.sum=(Tr.sum+t*Tr.len)%P;
T[p].tag=0;
}
void update(int x,int y,int v,int l,int r,int p) {
if (x<=l&&y>=r) {
T[p].tag=(T[p].tag+v)%P;
T[p].sum=(T[p].sum+1ll*v*T[p].len)%P;
return ;
}
pushdown(p);
int mid=(l+r>>1);
if (x<=mid) update(x,y,v,l,mid,p<<1);
if (y>mid) update(x,y,v,mid+1,r,p<<1|1);
pushup(p);
}
ll query(int x,int y,int l,int r,int p) {
if (x<=l&&y>=r) return T[p].sum;
pushdown(p);
int mid=(l+r>>1); ll ans=0;
if (x<=mid) ans=(ans+query(x,y,l,mid,p<<1))%P;
if (y>mid) ans=(ans+query(x,y,mid+1,r,p<<1|1))%P;
pushup(p);
return ans;
}