题目大意:已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:
操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z
操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和
操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z
操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和
解题思路:树链剖分。
剖完进行dfs遍历,并记录每个节点的dfs序(优先遍历重链)。
可以发现任何一条重链的dfs序都是连续的,并且任何一棵子树中所有节点的dfs序也是连续的。
我们用线段树来维护每个dfs序对应的节点的信息。
对于操作1和2,让两个节点往链顶跳,每条一次在线段树中更新或查询链顶到原来节点的dfs序的信息。
对于操作3和4,由于一棵子树中所有节点dfs序有序,直接修改或查询即可。若一棵子树的根节点的dfs序是x,子树大小是sz,那子树最大的一个dfs序是x+sz-1。
由于树剖重链和轻链数量都是log级的,加上线段树时间复杂度,总时间复杂度$O(mlog^2 n)$。
C++ Code:
#include<cstdio> #include<cctype> #include<cstring> #define ll long long #define N 120500 int n,m,rt,p,a[N],head[N],cnt,sz[N],fa[N],dep[N],son[N],dfn[N],idx; int aa[N],L,R,c,top[N]; ll ans; struct SegmentTreeNode{ ll sum,add; }d[N<<2]; struct edge{ int to,nxt; }e[N<<1]; inline int readint(){ char c=getchar(); for(;!isdigit(c);c=getchar()); int d=0; for(;isdigit(c);c=getchar()) d=(d<<3)+(d<<1)+(c^'0'); return d; } void dfs(int now){ sz[now]=1; for(int i=head[now];i;i=e[i].nxt) if(!dep[e[i].to]){ dep[e[i].to]=dep[now]+1; fa[e[i].to]=now; dfs(e[i].to); sz[now]+=sz[e[i].to]; if(son[now]==0||sz[e[i].to]>sz[son[now]])son[now]=e[i].to; } } void dfs2(int now){ dfn[now]=++idx; if(son[now])top[son[now]]=top[now],dfs2(son[now]); for(int i=head[now];i;i=e[i].nxt) if(dep[now]<dep[e[i].to]&&e[i].to!=son[now]) dfs2(top[e[i].to]=e[i].to); } inline void update(int l,int o){ int lft=o<<1; int rgt=lft|1; d[lft].add=(d[lft].add+d[o].add)%p; d[rgt].add=(d[rgt].add+d[o].add)%p; d[lft].sum=(d[lft].sum+d[o].add*((l+1)>>1))%p; d[rgt].sum=(d[rgt].sum+d[o].add*(l>>1))%p; d[o].add=0; } void build(int l,int r,int o){ if(l==r){ d[o]=(SegmentTreeNode){aa[l],0}; return; } int mid=(l+r)>>1; build(l,mid,o<<1); build(mid+1,r,o<<1|1); d[o].sum=(d[o<<1].sum+d[o<<1|1].sum)%p; d[o].add=0; } void add_T(int l,int r,int o){ if(L<=l&&r<=R){ d[o].add=(d[o].add+c)%p; d[o].sum=(d[o].sum+c*(r-l+1))%p; return; } int mid=(l+r)>>1; update(r-l+1,o); if(L<=mid)add_T(l,mid,o<<1); if(mid<R)add_T(mid+1,r,o<<1|1); d[o].sum=(d[o<<1].sum+d[o<<1|1].sum)%p; } void que_T(int l,int r,int o){ if(L<=l&&r<=R){ ans=(ans+d[o].sum)%p; return; } int mid=(l+r)>>1; update(r-l+1,o); if(L<=mid)que_T(l,mid,o<<1); if(mid<R)que_T(mid+1,r,o<<1|1); } void add_1(int x,int y){ for(;top[x]!=top[y];) if(dep[top[x]]>=dep[top[y]]){ L=dfn[top[x]],R=dfn[x]; add_T(1,n,1); x=fa[top[x]]; }else{ L=dfn[top[y]],R=dfn[y]; add_T(1,n,1); y=fa[top[y]]; } if(dep[x]<=dep[y]){ L=dfn[x],R=dfn[y]; add_T(1,n,1); }else{ L=dfn[y],R=dfn[x]; add_T(1,n,1); } } void que_1(int x,int y){ for(;top[x]!=top[y];) if(dep[top[x]]>=dep[top[y]]){ L=dfn[top[x]],R=dfn[x]; que_T(1,n,1); x=fa[top[x]]; }else{ L=dfn[top[y]],R=dfn[y]; que_T(1,n,1); y=fa[top[y]]; } if(dep[x]<=dep[y]){ L=dfn[x],R=dfn[y]; que_T(1,n,1); }else{ L=dfn[y],R=dfn[x]; que_T(1,n,1); } } int main(){ memset(dep,0,sizeof dep); memset(head,0,sizeof head); memset(son,0,sizeof son); cnt=idx=0; n=readint(),m=readint(),rt=readint(),p=readint(); for(int i=1;i<=n;++i)a[i]=readint()%p; for(int i=1;i<n;++i){ int u=readint(),v=readint(); e[++cnt]=(edge){v,head[u]}; head[u]=cnt; e[++cnt]=(edge){u,head[v]}; head[v]=cnt; } dep[top[rt]=rt]=1; fa[rt]=rt; dfs(rt); dfs2(rt); for(int i=1;i<=n;++i)aa[dfn[i]]=a[i]; build(1,n,1); while(m--){ int x=readint(),l,r; switch(x){ case 1: l=readint(),r=readint(),c=readint(); add_1(l,r); break; case 2: ans=0; l=readint(),r=readint(); que_1(l,r); printf("%d ",(int)(ans%p)); break; case 3: L=readint(),c=readint(); R=dfn[L]+sz[L]-1; L=dfn[L]; add_T(1,n,1); break; case 4: ans=0; L=readint(); R=dfn[L]+sz[L]-1; L=dfn[L]; que_T(1,n,1); printf("%d ",(int)(ans%p)); break; } } return 0; }