题目链接:https://www.luogu.org/problemnew/show/P3384
题意:树链剖分模板题,但是比较坑的是要注意取模,每个可能炸int的地方都要加取模。
代码如下:
#include<cstdio> #include<algorithm> using namespace std; const int maxn=100005; //链式向前星 int cnt1,head[maxn],w[maxn],wt[maxn]; int n,m,r,Mod,res; int son[maxn],id[maxn],fa[maxn],cnt2,dep[maxn],siz[maxn],top[maxn]; struct node1{ int v,nex; }edge[maxn<<1]; void adde(int u,int v){ edge[++cnt1].v=v; edge[cnt1].nex=head[u]; head[u]=cnt1; } //线段树部分 struct node2{ int l,r,val,add; }tr[maxn<<2]; void build(int v,int l,int r){ tr[v].l=l,tr[v].r=r; if(l==r){ tr[v].val=wt[r]%Mod; return; } int mid=(l+r)>>1; build(v<<1,l,mid); build(v<<1|1,mid+1,r); tr[v].val=(tr[v<<1].val+tr[v<<1|1].val+Mod)%Mod; } void pushdown(int v){ tr[v<<1].val+=tr[v].add*(tr[v<<1].r-tr[v<<1].l+1); tr[v<<1|1].val+=tr[v].add*(tr[v<<1|1].r-tr[v<<1|1].l+1); tr[v<<1].val%=Mod; tr[v<<1|1].val%=Mod; tr[v<<1].add+=tr[v].add; tr[v<<1|1].add+=tr[v].add; tr[v].add=0; } void update(int v,int l,int r,int k){ if(l<=tr[v].l&&r>=tr[v].r){ tr[v].val+=k*(tr[v].r-tr[v].l+1); tr[v].val%=Mod; tr[v].add+=k; return; } if(tr[v].add) pushdown(v); int mid=(tr[v].l+tr[v].r)>>1; if(l<=mid) update(v<<1,l,r,k); if(r>mid) update(v<<1|1,l,r,k); tr[v].val=(tr[v<<1].val+tr[v<<1|1].val+Mod)%Mod; } void query(int v,int l,int r){ if(l<=tr[v].l&&r>=tr[v].r){ res+=tr[v].val; res%=Mod; return; } if(tr[v].add) pushdown(v); int mid=(tr[v].l+tr[v].r)>>1; if(l<=mid) query(v<<1,l,r); if(r>mid) query(v<<1|1,l,r); tr[v].val=(tr[v<<1].val+tr[v<<1|1].val+Mod)%Mod; } int qRange(int x,int y){ int ans=0; while(top[x]!=top[y]){ if(dep[top[x]]<dep[top[y]]) swap(x,y); res=0; query(1,id[top[x]],id[x]); ans=(ans+res)%Mod; x=fa[top[x]]; } if(dep[x]>dep[y]) swap(x,y); res=0; query(1,id[x],id[y]); ans=(ans+res)%Mod; return ans; } void updRange(int x,int y,int k){ k%=Mod; while(top[x]!=top[y]){ if(dep[top[x]]<dep[top[y]]) swap(x,y); update(1,id[top[x]],id[x],k); x=fa[top[x]]; } if(dep[x]>dep[y]) swap(x,y); update(1,id[x],id[y],k); } int qSon(int x){ res=0; query(1,id[x],id[x]+siz[x]-1); return res; } void updSon(int x,int k){ update(1,id[x],id[x]+siz[x]-1,k); } void dfs1(int x,int f,int deep){ dep[x]=deep; fa[x]=f; siz[x]=1; int maxson=-1; for(int i=head[x];i;i=edge[i].nex){ int y=edge[i].v; if(y==f) continue; dfs1(y,x,deep+1); siz[x]+=siz[y]; if(siz[y]>maxson) son[x]=y,maxson=siz[y]; } } void dfs2(int x,int tp){ id[x]=++cnt2; wt[cnt2]=w[x]; top[x]=tp; if(!son[x]) return; dfs2(son[x],tp); for(int i=head[x];i;i=edge[i].nex){ int y=edge[i].v; if(y==fa[x]||y==son[x]) continue; dfs2(y,y); } } int main(){ scanf("%d%d%d%d",&n,&m,&r,&Mod); for(int i=1;i<=n;++i) scanf("%d",&w[i]); for(int i=0;i<n-1;++i){ int x,y; scanf("%d%d",&x,&y); adde(x,y); adde(y,x); } dfs1(r,0,1); dfs2(r,r); build(1,1,n); while(m--){ int op,x,y,z; scanf("%d",&op); if(op==1){ scanf("%d%d%d",&x,&y,&z); updRange(x,y,z); } else if(op==2){ scanf("%d%d",&x,&y); printf("%d ",qRange(x,y)); } else if(op==3){ scanf("%d%d",&x,&z); updSon(x,z); } else{ scanf("%d",&x); printf("%d ",qSon(x)); } } return 0; }