树链剖分,感觉是一个很神奇的东西,但是其实并不是那样的
树链剖分其实就是一个线段树
线段树处理的是连续区间,所以当你要加的时候都是连续区间修改
所以可以用轻重链的方式将树分解成为链条,然后用线段树处理
可以很容易看到,为什么用的是dfs但不是用的是bfs呢
因为dfs保持了重链是连续的,所以可以用top[x]记录已x为节点的重链最上方,一个点也包含在重链内
若修改区间为(u,v),但是重链的祖先是一起的,所以当他们的LCA相同时,边break
所以现在u,v是连续的
所以查询(u,v)的简单路径和也就处理了
所以说线段树中可以进行的操作在树上也可以执行了
在处理一个问题
在u的子树上加w
所以修改的区间是u在线段树中的位置$(t)$ 到 $t+size(u)-1$
$size$ 记录以它为根 的子节点个数
$deep(x)$ 深度
$father(x)$ 记录父亲
$son(x)$ 它的重儿子
$top(x)$ 所在重路径的顶部节点
$seg(x)$ x在线段树中的编号
$rev(x)$ 线段树中x的位置所对应的树中节点编号
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #include<cmath> using namespace std; inline int read() { int f=1,ans=0;char c; while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();} while(c>='0'&&c<='9'){ans=ans*10+c-'0';c=getchar();} return ans*f; } int n,val[1200001]; struct node{ int u,v,nex; }x[1200001]; int head[1200001],cnt; int size[1200001]; int deep[1200001]; int father[1200001]; int son[1200001]; int top[1200001]; int seg[1200001]; int rev[1200001]; int q,root,mod; void dfs1(int f,int fath) { deep[f]=deep[fath]+1; father[f]=fath; size[f]=1; for(int i=head[f];i!=-1;i=x[i].nex) { if(x[i].v==fath) continue; dfs1(x[i].v,f); size[f]+=size[x[i].v]; if(size[x[i].v]>size[son[f]]) son[f]=x[i].v; } return; } void dfs2(int f,int fath) { if(son[f]) { top[son[f]]=top[f]; seg[son[f]]=++seg[0]; rev[seg[0]]=son[f]; dfs2(son[f],f); } for(int i=head[f];i!=-1;i=x[i].nex) { if(x[i].v==fath) continue; if(top[x[i].v]) continue; top[x[i].v]=x[i].v; seg[x[i].v]=++seg[0]; rev[seg[0]]=x[i].v; dfs2(x[i].v,f); } return; } void add(int u,int v) { x[cnt].u=u,x[cnt].v=v,x[cnt].nex=head[u],head[u]=cnt++; } int ans[1200001],sum[1200001]; void build(int k,int l,int r) { if(l==r) { ans[k]=val[rev[l]]; return; } int mid=l+r>>1; build(k<<1,l,mid); build(k<<1|1,mid+1,r); ans[k]=ans[k<<1]+ans[k<<1|1]; return; } void push_down(int k,int l,int r) { int mid=l+r>>1; ans[k<<1]+=sum[k]*(mid-l+1);sum[k<<1]%=mod; sum[k<<1]+=sum[k];sum[k<<1]%=mod; ans[k<<1|1]+=sum[k]*(r-mid);ans[k<<1|1]%=mod; sum[k<<1|1]+=sum[k];sum[k<<1|1]%=mod; sum[k]=0; return; } void add(int k,int l,int r,int x,int y,int v) { if(x<=l&&r<=y){ sum[k]+=v; sum[k]%=mod; ans[k]+=((r-l+1)%mod)*v%mod; ans[k]%=mod; return; } push_down(k,l,r); int mid=l+r>>1; if(x<=mid) add(k<<1,l,mid,x,y,v); if(mid<y) add(k<<1|1,mid+1,r,x,y,v); ans[k]=ans[k<<1]+ans[k<<1|1]; ans[k]%=mod; } void ask_add(int x,int y,int w) { int fx=top[x],fy=top[y]; while(fx!=fy) { if(deep[fx]<deep[fy]) swap(x,y),swap(fx,fy); add(1,1,seg[0],seg[fx],seg[x],w%mod); x=father[fx],fx=top[x]; } if(deep[x]>deep[y]) swap(x,y); add(1,1,seg[0],seg[x],seg[y],w); } int summ; int query(int k,int l,int r,int x,int y) { if(x<=l&&r<=y) return ans[k]%mod; push_down(k,l,r); int res=0,mid=l+r>>1; if(x<=mid) res+=query(k<<1,l,mid,x,y)%mod; if(mid<y) res+=query(k<<1|1,mid+1,r,x,y)%mod; return res; } int ask(int x,int y) { summ=0; int fx=top[x],fy=top[y]; while(fx!=fy) { if(deep[fx]<deep[fy]) swap(x,y),swap(fx,fy); summ+=query(1,1,seg[0],seg[fx],seg[x])%mod; x=father[fx],fx=top[x]; } if(deep[x]>deep[y]) swap(x,y); summ+=query(1,1,seg[0],seg[x],seg[y])%mod; return summ%mod; } int main() { memset(head,-1,sizeof(head)); n=read(),q=read(),root=read(),mod=read(); for(int i=1;i<=n;i++) val[i]=read(); for(int i=1;i<n;i++) { int u=read(),v=read(); add(u,v),add(v,u); } dfs1(root,0); seg[0]=1;seg[root]=1; top[root]=root; rev[1]=root; dfs2(root,0); build(1,1,seg[0]); while(q--) { int s=read(); if(s==1) { int u=read(),v=read(); int w=read(); ask_add(u,v,w%mod); } if(s==3) { summ=0; int u=read(),v=read(); add(1,1,seg[0],seg[u],seg[u]+size[u]-1,v%mod); } if(s==2) { int u=read(),v=read(); printf("%lld ",ask(u,v)%mod); } if(s==4) { int u=read(); printf("%lld ",query(1,1,seg[0],seg[u],seg[u]+size[u]-1)%mod); } } return 0; }