Code:
// Heavy-light decomposition + lazy segment tree (range add / range sum modulo `mod`).
//
// Input : n m r P, then n node values, then n-1 undirected edges "a b".
//         r is the root, P the modulus. Then m operations:
//           1 x y z : add z to every node on the path x..y
//           2 x y   : print the sum of node values on the path x..y (mod P)
//           3 x z   : add z to every node in the subtree of x
//           4 x     : print the sum of node values in the subtree of x (mod P)
// Output: one line per type-2 / type-4 operation.
#include <cstdio>
#include <algorithm>
#include <iostream>
#include <cstring>

typedef long long ll;

const int maxn = 100000 + 2;

// Modulus read from the input; every stored sum and tag lies in [0, mod).
// NOTE(review): len * tag is computed in long long, so this assumes
// mod * maxn fits in long long (true for any 32-bit modulus).
ll mod;
int n;  // number of tree nodes

// ---- Tree storage and heavy-light decomposition ----
// Adjacency list; each undirected edge is stored twice.
int head[maxn], nex[maxn * 2], to[maxn * 2];
// p: parent (-1 for the root), dep: depth (root = 1), siz: subtree size,
// son: heavy child (-1 if none), top: head of the heavy chain containing the node.
int p[maxn], dep[maxn], siz[maxn], son[maxn], top[maxn];
int cnt;   // number of stored edges
int cnt2;  // DFS timestamp counter used by dfs2

// ---- Segment tree over DFS order ----
ll sumv[maxn * 4];  // sum of the node's segment, modulo mod
ll lazy[maxn * 4];  // pending add for the node's children, modulo mod
                    // (long long: tags are summed and multiplied by lengths)
// A[u] == st[u]: DFS position of u; ed[u]: last DFS position inside u's subtree.
int A[maxn], st[maxn], ed[maxn], val[maxn];
int dfn_val[maxn];  // dfn_val[A[u]] = val[u]; the array the segment tree is built from

// Reduces v into [0, mod), also when v is negative.
ll normalize_mod(ll v) {
    v %= mod;
    return v < 0 ? v + mod : v;
}

// Appends the directed edge u -> v to u's adjacency list.
void addedge(int u, int v) {
    nex[++cnt] = head[u];
    head[u] = cnt;
    to[cnt] = v;
}

// First DFS: records parent, depth, subtree size and heavy child of every node
// in u's subtree. Recursive, so the recursion depth equals the tree height.
void dfs1(int u, int fa, int cur) {
    p[u] = fa;
    dep[u] = cur;
    siz[u] = 1;
    for (int i = head[u]; i; i = nex[i]) {
        int v = to[i];
        if (v == fa) continue;
        dfs1(v, u, cur + 1);
        siz[u] += siz[v];
        if (son[u] == -1 || siz[v] > siz[son[u]]) son[u] = v;
    }
}

// Second DFS: assigns DFS positions visiting the heavy child first, so every
// heavy chain and every subtree occupies a contiguous range of positions.
void dfs2(int u, int tp) {
    top[u] = tp;
    A[u] = st[u] = ++cnt2;
    if (son[u] != -1) dfs2(son[u], tp);
    for (int i = head[u]; i; i = nex[i]) {
        int v = to[i];
        if (v != p[u] && v != son[u]) dfs2(v, v);
    }
    ed[u] = cnt2;
}

// Adds k (already in [0, mod)) to each of the len positions covered by node o.
void apply_tag(int o, int len, ll k) {
    lazy[o] = (lazy[o] + k) % mod;
    sumv[o] = (sumv[o] + len * k) % mod;  // len * k is a long long product
}

// Recomputes node o's sum from its two children.
void pull(int o) {
    sumv[o] = (sumv[o * 2] + sumv[o * 2 + 1]) % mod;
}

// Pushes node o's pending tag (node covers [L, R]) down to its two children.
void down(int L, int R, int o) {
    if (lazy[o]) {
        int mid = (L + R) / 2;
        apply_tag(o * 2, mid - L + 1, lazy[o]);
        apply_tag(o * 2 + 1, R - mid, lazy[o]);
        lazy[o] = 0;
    }
}

// Builds node o over positions [L, R] from arr[L..R].
void build(int L, int R, int o, int arr[]) {
    if (L == R) {
        sumv[o] = normalize_mod(arr[L]);
        return;
    }
    int mid = (L + R) / 2;
    build(L, mid, o * 2, arr);
    build(mid + 1, R, o * 2 + 1, arr);
    pull(o);
}

// Adds k to every position in [l, r]; node o covers [L, R].
void update(int l, int r, int k, int L, int R, int o) {
    if (l <= L && r >= R) {
        apply_tag(o, R - L + 1, normalize_mod(k));
        return;
    }
    int mid = (L + R) / 2;
    down(L, R, o);
    if (l <= mid) update(l, r, k, L, mid, o * 2);
    if (r > mid) update(l, r, k, mid + 1, R, o * 2 + 1);
    pull(o);
}

// Returns the sum of positions [l, r] modulo mod; node o covers [L, R].
ll query(int l, int r, int L, int R, int o) {
    if (l <= L && r >= R) return sumv[o];
    int mid = (L + R) / 2;
    down(L, R, o);
    ll ret = 0;
    if (l <= mid) ret += query(l, r, L, mid, o * 2);
    if (r > mid) ret += query(l, r, mid + 1, R, o * 2 + 1);
    return ret % mod;
}

// Adds del to every node on the tree path between x and y.
void up(int x, int y, int del) {
    while (top[x] != top[y]) {
        // Lift the endpoint whose chain head is deeper; its whole chain segment is on the path.
        if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
        update(A[top[x]], A[x], del, 1, n, 1);
        x = p[top[x]];
    }
    // Same chain now: the shallower node has the smaller DFS position.
    if (dep[x] > dep[y]) std::swap(x, y);
    update(A[x], A[y], del, 1, n, 1);
}

// Returns the sum of node values on the tree path between x and y, modulo mod.
ll look_up(int x, int y) {
    ll sum = 0;
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
        sum = (sum + query(A[top[x]], A[x], 1, n, 1)) % mod;
        x = p[top[x]];
    }
    if (dep[x] > dep[y]) std::swap(x, y);
    return (sum + query(A[x], A[y], 1, n, 1)) % mod;
}

int main() {
    int m = 0, r = 1;
    // Stop on a malformed header (m would otherwise be used uninitialized),
    // on a non-positive modulus (division by zero) or on an empty tree.
    if (scanf("%d%d%d", &n, &m, &r) != 3 || scanf("%lld", &mod) != 1) return 0;
    if (n <= 0 || mod <= 0) return 0;
    for (int i = 1; i <= n; ++i) scanf("%d", &val[i]);
    for (int i = 1; i < n; ++i) {
        int a, b;
        scanf("%d%d", &a, &b);
        addedge(a, b);
        addedge(b, a);
    }
    std::memset(son, -1, sizeof(son));
    dfs1(r, -1, 1);
    dfs2(r, r);
    for (int i = 1; i <= n; ++i) dfn_val[A[i]] = val[i];
    build(1, n, 1, dfn_val);
    while (m--) {
        int op;
        if (scanf("%d", &op) != 1) break;
        if (op == 1) {
            int x, y, z;
            scanf("%d%d%d", &x, &y, &z);
            up(x, y, z);
        } else if (op == 2) {
            int x, y;
            scanf("%d%d", &x, &y);
            printf("%lld\n", look_up(x, y));
        } else if (op == 3) {
            int x, z;
            scanf("%d%d", &x, &z);
            update(st[x], ed[x], z, 1, n, 1);
        } else if (op == 4) {
            int x;
            scanf("%d", &x);
            printf("%lld\n", query(st[x], ed[x], 1, n, 1));
        }
    }
    return 0;
}