树链剖分
树这个结构,本身很优美,但有些涉及区间的问题用树来做就会比较别扭。有一些方便处理区间的数据结构,可以和树结合一下。比如线段树。
定义
重儿子:所有儿子中节点数最多的儿子
轻儿子:除重儿子以外的儿子
重边:链接重儿子的链
轻边:剩余的边
重链:链接重儿子的链
轻链:其余的链
思想
将树按照一条条链剖开,就可以用线段树处理了。
如图,红色为重链,黑色为轻链。
如何维护?
两遍(dfs)。
//第一遍
int son[100005],sz[100005],dep[100005],f[100005];
//重儿子 子树大小 深度 父亲节点
void dfs(int u,int fa){
f[u]=fa,sz[u]=1,dep[u]=dep[fa]+1;//记录
int ma=0;//存重儿子大小
for(int i=head[u],v;v=a[i].to,i;i=a[i].next){//遍历每条相连的边
if(v==fa) continue;//如果是父亲,跳过
dfs(v,u);
sz[u]+=sz[v];//统计子树大小
if(sz[v]>ma) son[u]=v,ma=sz[v];//不断更新重儿子
}
}
//第二遍
int top[100005],pt[100005],dfn[100005],num;
// 链头 / 该节点对应的dfs序 / 该dfs序对应的节点 / dfs序
void dfs(int u,int fa,int tp){//tp是链头
top[u]=tp,pt[u]=++num,dfn[num]=u;
if(!son[u]) return;//若没有儿子,就返回
dfs(son[u],u,tp);//递归搜索重儿子,链头不变
for(int i=head[u],v;v=a[i].to,i;i=a[i].next){
if(v==fa||v==son[u]) continue;
dfs(v,u,v);//搜索轻儿子,链头为它自己
}
}
两遍(dfs)之后,一棵树就被我们抽筋剥骨变成若干条不相干的链了。
然后我们就可以用线段树维护一些乱七八糟的东西了。
提供一道例题
题目描述
一棵树上有(n)个节点,编号分别为(1)到(n),每个节点都有一个权值(w)。
我们将以下面的形式来要求你对这棵树完成一些操作:
$I. $$CHANGE$ (u) (t): 把结点(u)的权值改为(t)
(II. QMAX) (u) (v): 询问从点(u)到点(v)的路径上的节点的最大权值
(III. QSUM) (u) (v): 询问从点(u)到点(v)的路径上的节点的权值和
注意:从点(u)到点(v)的路径上的节点包括(u)和(v)本身
输入输出格式
输入格式:
输入文件的第一行为一个整数(n),表示节点的个数。
接下来(n – 1)行,每行(2)个整数(a)和(b),表示节点(a)和节点(b)之间有一条边相连。
接下来一行(n)个整数,第(i)个整数(wi)表示节点(i)的权值。
接下来(1)行,为一个整数(q),表示操作的总数。
接下来(q)行,每行一个操作,以“(CHANGE) (u) (t)”或者(“QMAX) (u) (v”)或者(“QSUM) (u) (v”)的形式给出。
输出格式:
对于每个(“QMAX”)或者(“QSUM”)的操作,每行输出一个整数表示要求输出的结果。
输入输出样例
输入样例
4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
输出样例
4
1
2
2
10
6
5
6
5
16
很显然的树链剖分
细节见代码
#include<iostream>
#include<cstring>
#include<cstdio>
using namespace std;
long long read(){
long long x=0;int f=0;char c=getchar();
while(c<'0'||c>'9')f|=c=='-',c=getchar();
while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+(c^48),c=getchar();
return f?-x:x;
}
int n,q,w[100005];
struct Dier{//结构体存边
int to,next;
}a[100005];
int head[100005],cnt;
void add(int x,int y){//前向星
a[++cnt].to=y,a[cnt].next=head[x],head[x]=cnt;
}
//对树 抽筋剥骨
int son[100005],sz[100005],dep[100005],f[100005];
void dfs(int u,int fa){//第一遍dfs
f[u]=fa,sz[u]=1,dep[u]=dep[fa]+1;
int ma=0;
for(int i=head[u],v;v=a[i].to,i;i=a[i].next){
if(v==fa) continue;
dfs(v,u);
sz[u]+=sz[v];
if(sz[v]>ma) son[u]=v,ma=sz[v];
}
}
int top[100005],pt[100005],dfn[100005],num;
void dfs(int u,int fa,int tp){//第二遍dfs
top[u]=tp,pt[u]=++num,dfn[num]=u;
if(!son[u]) return;
dfs(son[u],u,tp);
for(int i=head[u],v;v=a[i].to,i;i=a[i].next){
if(v==fa||v==son[u]) continue;
dfs(v,u,v);
}
}
//线段树
struct xtm{
int maxx,sum;//结构体存线段树
xtm(){maxx=-30000;}
}t[120005];//线段树要开四倍空间,被神仙嘲笑++
#define lc p<<1//左儿子
#define rc p<<1|1//右儿子
void pushup(int p){//更新上面的点
t[p].sum=t[lc].sum+t[rc].sum;//由左儿子和右儿子得来
t[p].maxx=max(t[lc].maxx,t[rc].maxx);
}
void build(int p,int l,int r){//初始化线段树
if(l==r){
//dfn[]记录的是当前序号对应的是几号节点,我们记录的左右区间及dfs序,因此直接用l找点即可
//被神仙嘲笑++
t[p].maxx=t[p].sum=w[dfn[l]];return;
}
int m=(l+r)>>1;
build(lc,l,m),build(rc,m+1,r);//线段树常规更新
pushup(p);
}
void updata(int p,int l,int r,int k,int x){//更改某点的值
if(l==r){
t[p].sum=t[p].maxx=x;return;
}
int m=(l+r)>>1;
if(m>=k) updata(lc,l,m,k,x);
else updata(rc,m+1,r,k,x);
pushup(p);
}
//求最大值
int find_max(int p,int l,int r,int L,int R){
if(l>R||r<L) return -30000;//细节,被神仙嘲笑++
if(l>=L&&r<=R) return t[p].maxx;
int m=(l+r)>>1;
return max(find_max(lc,l,m,L,R),find_max(rc,m+1,r,L,R));
}
void get_max(){
int ans=-30000,x=read(),y=read();
while(top[x]!=top[y]){//只要两个点不在一条链上,我们就可以让链头深度低的往上跳,直到两点在一条链上
//一开始用lca的我,被神仙嘲笑++
if(dep[top[x]]>dep[top[y]]) swap(x,y);
ans=max(ans,find_max(1,1,n,pt[top[y]],pt[y])),y=f[top[y]];//直接跳到链头父节点
}
if(dep[x]>dep[y]) swap(x,y);
//这里x和y已经在一条链上了,因此要判断两点的深度,而非链头深度,被神仙嘲笑++
ans=max(ans,find_max(1,1,n,pt[x],pt[y]));
printf("%d
",ans);
}
//求和
int find_sum(int p,int l,int r,int L,int R){
if(l>R||r<L) return 0;
if(l>=L&&r<=R) return t[p].sum;
int m=(l+r)>>1;
return find_sum(lc,l,m,L,R)+find_sum(rc,m+1,r,L,R);
}
void get_sum(){
int ans=0,x=read(),y=read();
while(top[x]!=top[y]){//同上
if(dep[top[x]]>dep[top[y]]) swap(x,y);
ans+=find_sum(1,1,n,pt[top[y]],pt[y]),y=f[top[y]];
}
if(dep[x]>dep[y]) swap(x,y);
ans+=find_sum(1,1,n,pt[x],pt[y]);
printf("%d
",ans);
}
int main(){
n=read();
for(int i=1,x,y;i<n;++i){//读入,连边
x=read(),y=read();
add(x,y),add(y,x);
}
for(int i=1;i<=n;++i) w[i]=read();
dfs(1,0),dfs(1,0,1),build(1,1,n);//预处理
q=read();
string s;
while(q--){
cin>>s;
if(s[0]=='C'){int u=read(),t=read();updata(1,1,n,pt[u],t);}
else if(s[1]=='M') get_max();
else if(s[1]=='S') get_sum();
}
//cout<<被神仙嘲笑;GG
return 0;
}
欢迎指正评论O(∩_∩)O~~