树链剖分:
用于解决一系列维护静态树上信息的问题。这些问题看起来非常像一些区间操作搬到了树上。
(例如:一棵带权树,需要维护修改权值操作以及从$u$到$v$简单路径上的权值和)
树链剖分就是通过某种策略(一般是轻、重边剖分)将原树链划分成若干条链,每条链相当于一个序列,此时就可以用区间数据结构(一般是线段树)维护这些链。
需要维护的值:
$f(x)$:$x$在树中的父亲。
$dep(x)$:$x$在树中的深度。
$siz(x)$:$x$的子树大小。
$son(x)$:$u$的重儿子:在$u$的所有儿子中$siz$值最大的儿子,$u ightarrow v$为重边。
($u$的轻儿子:在$u$的所有儿子中除了重儿子以外的儿子,$u ightarrow v$为轻边。)
$top(x)$:$x$所在重路径的顶部节点。
$seg(x)$:$x$在线段树中的位置(下标)。
$rnk(x)$:线段树中$x$位置对应的树中节点编号,即有$rnk(seg(x))=x$。
轻重边的一些性质:
1、如果$u ightarrow v$为轻边,则$siz(v)<=siz(u)/2$。
证明:反证法,若存在$siz(v)>siz(u)/2$且存在$siz(v_0)>siz(v)$,那么$siz(v)+siz(v_0)>siz(u)$,即子节点的$siz$和大于父节点的$siz$。
2、从根到任何点$u$的路径上轻边的条数不超过$log(N)$。
证明:由1可知从根到$u$的路径上每经过一条轻边,当前子树的节点个数至少会少$frac{1}{2}$,所以至多减少$log(N)$次$siz$值为0,到达叶节点。
3、从根到任何点$u$的路径上轻边、重边的条数均不超过$log(N)$。
证明:每条重链的起点和终点都连接一条轻边,由2可知轻边条数不超过$log(N)$,所以重链条数也不超过$log(N)$。
实现步骤:
1、一遍$dfs$得到前4个值,再一遍$dfs$将树的节点重新排序,使一条重链上的点$dfs$序连续。
2、使用线段树维护新树的$dfs$序序列,查询时沿重链走到两点的$lca$并计算答案。
模板题目:loj10138
#include<algorithm> #include<iostream> #include<cstring> #include<cstdio> using namespace std; #define MAXN 100005 #define MAXM 500005 #define INF 0x7fffffff #define ll long long int hd[MAXN],to[MAXN<<1],top[MAXN]; int A[MAXN],nxt[MAXN<<1],cnt,tot; int f[MAXN],siz[MAXN],son[MAXN]; int seg[MAXN],rnk[MAXN],dep[MAXN]; struct node{int l,r,sum,mx;}tr[MAXN<<2]; char str[10]; inline int read(){ int x=0,f=1; char c=getchar(); for(;!isdigit(c);c=getchar()) if(c=='-') f=-1; for(;isdigit(c);c=getchar()) x=x*10+c-'0'; return x*f; } inline void add(int u,int v){ to[++cnt]=v,nxt[cnt]=hd[u]; hd[u]=cnt;return; } inline void pushup(int k){ tr[k].mx=max(tr[k<<1].mx,tr[k<<1|1].mx); tr[k].sum=tr[k<<1].sum+tr[k<<1|1].sum; return; } inline void dfs1(int u,int fa,int d){ dep[u]=d;f[u]=fa;siz[u]=1; for(int i=hd[u];i;i=nxt[i]){ int v=to[i]; if(v==fa) continue; dfs1(v,u,d+1); siz[u]+=siz[v]; if(siz[v]>siz[son[u]]) son[u]=v; } return; } inline void dfs2(int u,int fa,int tp){ top[u]=tp;seg[u]=++tot;rnk[tot]=u; if(son[u]) dfs2(son[u],u,tp); for(int i=hd[u];i;i=nxt[i]){ int v=to[i]; if(v==fa || v==son[u]) continue; dfs2(v,u,v); } return; } inline void build(int L,int R,int k){ tr[k].l=L,tr[k].r=R; if(L==R){ tr[k].mx=tr[k].sum=A[rnk[L]]; return; } int mid=(L+R)>>1; build(L,mid,k<<1); build(mid+1,R,k<<1|1); pushup(k);return; } inline void update(int x,int y,int k){ if(tr[k].l==tr[k].r){ tr[k].mx=tr[k].sum=y; return; } int mid=(tr[k].l+tr[k].r)>>1; if(x<=mid) update(x,y,k<<1); else update(x,y,k<<1|1); pushup(k);return; } inline int qmx(int L,int R,int k){ if(L<=tr[k].l && tr[k].r<=R) return tr[k].mx; int mid=(tr[k].l+tr[k].r)>>1; if(L<=mid && R>mid) return max(qmx(L,R,k<<1),qmx(L,R,k<<1|1)); else if(R<=mid) return qmx(L,R,k<<1); else return qmx(L,R,k<<1|1); } inline int qsum(int L,int R,int k){ if(L<=tr[k].l && tr[k].r<=R) return tr[k].sum; int mid=(tr[k].l+tr[k].r)>>1; if(L<=mid && R>mid) return qsum(L,R,k<<1)+qsum(L,R,k<<1|1); else if(R<=mid) return qsum(L,R,k<<1); else return qsum(L,R,k<<1|1); } inline int solve1(int u,int v){ int ans=-INF; while(top[u]!=top[v]){ if(dep[top[u]]<dep[top[v]]) swap(u,v); ans=max(ans,qmx(seg[top[u]],seg[u],1)); u=f[top[u]]; } if(dep[u]<dep[v]) swap(u,v); ans=max(ans,qmx(seg[v],seg[u],1)); return ans; } inline int solve2(int u,int v){ int ans=0; while(top[u]!=top[v]){ if(dep[top[u]]<dep[top[v]]) swap(u,v); ans+=qsum(seg[top[u]],seg[u],1); u=f[top[u]]; } if(dep[u]<dep[v]) swap(u,v); ans+=qsum(seg[v],seg[u],1); return ans; } int main(){ int N=read(); for(int i=1;i<N;i++){ int u=read(),v=read(); add(u,v);add(v,u); } for(int i=1;i<=N;i++) A[i]=read(); dfs1(1,0,1);dfs2(1,0,1);build(1,N,1); int M=read(); while(M--){ cin>>str;int x=read(),y=read(); if(str[0]=='C') update(seg[x],y,1); else if(str[1]=='M') printf("%d ",solve1(x,y)); else printf("%d ",solve2(x,y)); } return 0; }