首发于摸鱼世界&更好的阅读体验
到现在也只会照着std打板子..
虽然这样,
毒树链剖分还是一个非常优雅的算法。
前置芝士:(DFS),线段树
树链剖分可以把树上的区间操作通过把树剖成一条条链,利用线段树等数据结构进行维护,从而达到(O(nlogn))的优秀时间复杂度。
比如这样的操作:
在一棵树上,将(x)到(y)路径上点的点权加上(w),并要求支持查询两个点(x,y)路径间的点权和。
乍一看,两个操作都很简单。修改操作可以用树上差分(O(1))乱搞,静态查询可以用(LCA)完成。
但是合起来就没有办法了:每次查询之前都需要(O(n))预处理,数据略大直接(T)飞。
于是树剖出场了。
区间修改&查询是线段树的强项,但是它只能对一段连续的区间进行查询。于是我们需要想办法让树上需要操作的路径变成一段连续的区间。
引入一个概念:重儿子,也就是一个节点的儿子中(size)最大的。连接到重儿子的边即为重边
重儿子组成的链,就是重链。
比如在这棵树中,连续的红边组成的就是一条条重链。我们用(top[u])记录节点(u)所在重链的顶端。特别地,没有被重边连接的节点,(top[u]=u),即它们所在重链的顶端就是自身。注意到,当(u)是一条重链的顶端((top[u]=u))时,它的父节点一定在另一条重链上。
始终记住我们的目标:把在树上区间操作转化为在一段连续的区间进行操作。
考虑如何用(DFS)给树上的每个节点在区间内找到一个合适的位置。我们发现,从根节点出发,优先走重边,这样的(dfs)序似乎有点特殊。
例如上图,优先走重边的(dfs)序为:(124798356)。很显然,这样的(dfs)序满足同一条重链上的点(dfs)序连续。所以用线段树维护的,就是重链上的信息。
这样操作之后,我们可以做到的是:(O(logn))对一条重链上的信息区间修改,区间查询。
对于两个节点(u,v),我们可以通过不断地跳重链,直到两个节点在同一条重链上。这个是很好实现的,因为只需要跳到(fa[top[u]]),就到了一条新的重链。
代码实现仅树剖部分是不麻烦的。我们需要维护的信息有(dep)(节点深度),(fa)(父节点),(son)(重儿子),(sz)(子树节点数,用来判重儿子),这些可以用一次(dfs)完成。
void dfs1(int u,int f,int d)//fa,dep,son,sz
{
fa[u]=f;
dep[u]=d;
sz[u]=1;
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
if(v!=f)
{
dfs1(v,u,d+1);
sz[u]+=sz[v];
if(sz[v]>sz[son[u]])son[u]=v;
}
}
}
接下来,就需要把这棵树每个节点压到线段树维护的序列的一个位置了。就像上文说的一样,按照优先重边的(dfs)序压入线段树即可。于是记录一个(id[i])表示原树中节点(i)对应的线段树中的下标。(rk[i])反过来记录线段树中下标为(i)的原数编号。
由于预处理了父节点,所以(dfs2)传参只需要(u)(当前节点)和(t)(当前重链顶端节点)。在遍历儿子之前先(dfs2(son[u],t)),因为(u)和(u)的重儿子在同一条重链上。接下来才遍历轻(非重)儿子(v),但是传参为(dfs2(v,v)),因为(v)就是新的一条重链的起点。
void dfs2(int u,int t)//top,id,rk
{
top[u]=t;
id[u]=++tot;
rk[tot]=u;
if(!son[u])return;
dfs2(son[u],t);
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
if(v!=fa[u]&&v!=son[u])
dfs2(v,v);
}
}
再回到最开始的问题:
在一棵树上,将(x)到(y)路径上点的点权加上(w),并要求支持查询两个点(x,y)路径间的点权和。
答案就显得很明了了。
如果是查询,先保证(dep[x]>dep[y]),然后就和(LCA)类似的,利用重链加速:每次把([top[x],x])这条重链的和累加到答案上,再使(x)跳到另一条重链上,即(x=fa[top[x]]),直到(x,y)在同一条重链上,再把两个点之间的信息统计累加一下即可。
int getsum(int x,int y)
{
int res=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
sum=0;
asksum(1,id[top[x]],id[x]);
(res+=sum)%=mod;
x=fa[top[x]];
}
if(id[x]>id[y])swap(x,y);
sum=0;
asksum(1,id[x],id[y]);
(res+=sum)%=mod;
return res;
}
修改同理。
于是我们发现,虽然我们采用了优先重边的(dfs)序,但它毕竟遍历的都是自己的儿子节点。所以...还可以支持子树操作。因为一棵子树在重边优先的(dfs)序中编号也是连续的。并且这个编号很容易算,因为我们维护了一个(sz)信息。所以树中(x)节点的子树对应的就是线段树维护的([id[x],id[x]+sz[x]-1])这个区间。
于是还是板子一般的线段树区间修改&查询。
可以注意到线段树部分基本没讲,因为每个人写线段树的方法可能不太一样,蒟蒻我分享的只是树剖的思想。
另外,为什么树剖每次操作是(O(logn))呢?利用线段树的子树操作自然是(O(logn)),剩下的就是那个像(LCA)一样的跳重链。
证明:从任意节点向根节点跳重链,经过的重链和轻边(非重边)都是(log)级别的。
考虑到每走一条轻边,子树大小至少翻倍,否则这就不是条轻边了。于是经过的轻边就最多为(log_2 n)条。而重链和轻边的交替出现的,所以数量也在这个级别。
于是每次操作就只有(O(logn))的时间复杂度。
以下是代码
#include<bits/stdc++.h>
#define int long long
#define ls (k<<1)
#define rs (k<<1|1)
using namespace std;
const int N=1e5+10;
struct node
{
int l,r,w,f;
}t[N<<2];
int a[N];
int n,m,r,mod;
int sum;
int head[N<<1],to[N<<1],nxt[N<<1],cnt;
int sz[N],fa[N],dep[N],son[N];
int top[N],id[N],rk[N],tot;
inline int read()
{
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
return x*f;
}
void add(int u,int v)
{
cnt++;
to[cnt]=v;
nxt[cnt]=head[u];
head[u]=cnt;
}
void dfs1(int u,int f)
{
fa[u]=f;
sz[u]=1;
dep[u]=dep[f]+1;
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
if(v==f)continue;
dfs1(v,u);
sz[u]+=sz[v];
if(sz[v]>sz[son[u]])son[u]=v;
}
return;
}
void dfs2(int u,int t)
{
top[u]=t;
id[u]=++tot;
rk[tot]=u;
if(!son[u])return;
dfs2(son[u],t);
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
if(v!=fa[u]&&v!=son[u])dfs2(v,v);//新的重链
}
}
void build(int k,int l,int r)
{
t[k].l=l,t[k].r=r;
if(l==r)
{
t[k].w=a[rk[l]];
return;
}
int m=l+r>>1;
build(ls,l,m);
build(rs,m+1,r);
t[k].w=t[ls].w+t[rs].w;
return;
}
void down(int k)
{
t[ls].w+=(t[ls].r-t[ls].l+1)*t[k].f;
t[rs].w+=(t[rs].r-t[rs].l+1)*t[k].f;
t[ls].f+=t[k].f;
t[rs].f+=t[k].f;
t[k].f=0;
}
void addsum(int k,int x,int y,int p)
{
int l=t[k].l,r=t[k].r;
if(x<=l&&r<=y)
{
t[k].w+=(r-l+1)*p;
t[k].f+=p;
return;
}
down(k);
int m=l+r>>1;
if(x<=m)addsum(ls,x,y,p);
if(y>m)addsum(rs,x,y,p);
t[k].w=t[ls].w+t[rs].w;
return;
}
void asksum(int k,int x,int y)
{
int l=t[k].l,r=t[k].r;
if(x<=l&&r<=y)
{
sum+=t[k].w;
return;
}
down(k);
int m=l+r>>1;
if(x<=m)asksum(ls,x,y);
if(y>m)asksum(rs,x,y);
t[k].w=t[ls].w+t[rs].w;
return;
}
//-----------------------------
int getsum(int x,int y)
{
int res=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
sum=0;
asksum(1,id[top[x]],id[x]);
(res+=sum)%=mod;
x=fa[top[x]];
}
if(id[x]>id[y])swap(x,y);
sum=0;
asksum(1,id[x],id[y]);
(res+=sum)%=mod;
return res;
}
void update(int x,int y,int p)
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
addsum(1,id[top[x]],id[x],p);
x=fa[top[x]];
}
if(id[x]>id[y])swap(x,y);
addsum(1,id[x],id[y],p);
return;
}
signed main()
{
n=read(),m=read(),r=read(),mod=read();
for(int i=1;i<=n;i++)a[i]=read();
for(int i=1;i<n;i++)
{
int x=read(),y=read();
add(x,y),add(y,x);
}
dfs1(r,0);
dfs2(r,r);
build(1,1,n);
for(int i=1;i<=m;i++)
{
int x,y,z;
int opt=read();
if(opt==1)
{
x=read(),y=read(),z=read();
update(x,y,z);
}
if(opt==2)
{
x=read(),y=read();
printf("%lld
",getsum(x,y)%mod);
}
if(opt==3)
{
x=read(),z=read();
addsum(1,id[x],id[x]+sz[x]-1,z);
}
if(opt==4)
{
x=read();
sum=0;asksum(1,id[x],id[x]+sz[x]-1);
printf("%lld
",sum%mod);
}
}
return 0;
}
代码的确是长,也不算容易调,但是真正妙的是利用轻重链的思想进行的化树为链。
感谢阅读。