题意:
给出一棵树,请你支持三种操作:
(1)指定v为根节点
(2)给出u v x,将LCA(u,v)的子树加上x。
(3)求u的子树权值和。
题解:
换根的过程中维护子树,不能真的换根,尝试分类讨论根的位置,在原树上处理出当前树形下节点的子树区间。
//换根
//u v x,求LCA(u,v),把这个子树加上x
//查询u的子树点权和
//找到lca(root,u),lca(root,v),lca(u,v)
//中最深的那个
//就是当前的LCA
//然后判断LCA和根的相对位置
//如果根在LCA的子树外,LCA的当前子树就是原来的子树
//如果根在LCA的子树内,就更新除根所在的那颗子树的全部子树
//根距离LCA最近的祖先用倍增法向上跳即可,那个祖先的子树全部不更新
//查询u的子树点权和,先确定根和u的lca,如果不是u,那么就是原来的子树
//如果是u,那么就是除了根所在的那颗子树以外的全部子树
//同样用倍增法跳即可
//更新子树的点权用DFS序+线段树维护即可
//查询同理
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+100;
int n,q,a[maxn];
vector<int> g[maxn];
int father[30][maxn],h[maxn];
int dfn[maxn],size[maxn],tot,w[maxn];
void dfs (int x) {
size[x]=1;
dfn[x]=++tot;
w[dfn[x]]=a[x];
for (int y:g[x]) {
if (y==father[0][x]) continue;
father[0][y]=x;
h[y]=h[x]+1;
dfs(y);
size[x]+=size[y];
}
}
int lca (int x,int y) {
if (h[x]<h[y]) swap(x,y);
for (int i=20;i>=0;i--) if (h[x]-h[y]>>i) x=father[i][x];
if (x==y) return x;
for (int i=20;i>=0;i--) {
if (father[i][x]!=father[i][y]) {
x=father[i][x];
y=father[i][y];
}
}
return father[0][x];
}
struct node {
int l,r;
long long sum;
long long lazy;
}segTree[maxn<<2];
void build (int i,int l,int r) {
segTree[i].l=l;
segTree[i].r=r;
if (l==r) {
segTree[i].sum=w[l];
return;
}
int mid=(l+r)>>1;
build(i<<1,l,mid);
build(i<<1|1,mid+1,r);
segTree[i].sum=segTree[i<<1].sum+segTree[i<<1|1].sum;
}
void spread (int i) {
if (segTree[i].lazy) {
segTree[i<<1].sum+=1ll*(segTree[i<<1].r-segTree[i<<1].l+1)*segTree[i].lazy;
segTree[i<<1].lazy+=segTree[i].lazy;
segTree[i<<1|1].sum+=1ll*(segTree[i<<1|1].r-segTree[i<<1|1].l+1)*segTree[i].lazy;
segTree[i<<1|1].lazy+=segTree[i].lazy;
segTree[i].lazy=0;
}
}
void up (int i,int l,int r,int x) {
if (l>r) return;
if (segTree[i].l>=l&&segTree[i].r<=r) {
segTree[i].sum+=1ll*(segTree[i].r-segTree[i].l+1)*x;
segTree[i].lazy+=x;
return;
}
spread(i);
int mid=(segTree[i].l+segTree[i].r)>>1;
if (l<=mid) up(i<<1,l,r,x);
if (r>mid) up(i<<1|1,l,r,x);
segTree[i].sum=segTree[i<<1].sum+segTree[i<<1|1].sum;
}
long long query (int i,int l,int r) {
if (l>r) return 0;
if (segTree[i].l>=l&&segTree[i].r<=r) {
return segTree[i].sum;
}
spread(i);
int mid=(segTree[i].l+segTree[i].r)>>1;
long long ans=0;
if (l<=mid) ans+=query(i<<1,l,r);
if (r>mid) ans+=query(i<<1|1,l,r);
return ans;
}
int rt=1;
int main () {
scanf("%d%d",&n,&q);
for (int i=1;i<=n;i++) scanf("%d",a+i);
for (int i=1;i<n;i++) {
int x,y;
scanf("%d%d",&x,&y);
g[x].push_back(y);
g[y].push_back(x);
}
dfs(1);
for (int i=1;i<=20;i++) for (int j=1;j<=n;j++) father[i][j]=father[i-1][father[i-1][j]];
build(1,1,n);
while (q--) {
int op;
scanf("%d",&op);
if (op==1) {
int v;
scanf("%d",&v);
rt=v;
}
else if (op==2) {
int u,v,x;
scanf("%d%d%d",&u,&v,&x);
//for (int i=1;i<=n;i++) printf("%lld ",query(1,dfn[i],dfn[i]));
//printf("
");
int p=-1,Max=-1;
if (h[lca(u,v)]>Max) {
p=lca(u,v);
Max=h[lca(u,v)];
}
if (h[lca(rt,u)]>Max) {
p=lca(rt,u);
Max=h[lca(rt,u)];
}
if (h[lca(rt,v)]>Max) {
p=lca(rt,v);
Max=h[lca(rt,v)];
}
if (p==rt) {
up(1,1,n,x);
continue;
}
int lc=lca(p,rt);
if (lc!=p) {
up(1,dfn[p],dfn[p]+size[p]-1,x);
}
else {
int A=h[p]+1,B=rt;
for (int i=20;i>=0;i--) {
if (h[B]-A>>i) B=father[i][B];
}
up(1,1,dfn[B]-1,x);
up(1,dfn[B]+size[B],n,x);
//printf("%d %d
",dfn[B],dfn[B]+size[B]-1);
}
//for (int i=1;i<=n;i++) printf("%lld ",query(1,dfn[i],dfn[i]));
//printf("
");
}
else {
int v;
scanf("%d",&v);
int lc=lca(rt,v);
if (v==rt) {
printf("%lld
",query(1,1,n));
continue;
}
if (lc!=v) {
printf("%lld
",query(1,dfn[v],dfn[v]+size[v]-1));
}
else {
int A=h[v]+1,B=rt;
for (int i=20;i>=0;i--)
if (h[B]-A>>i) B=father[i][B];
long long ans=0;
ans+=query(1,1,dfn[B]-1);
ans+=query(1,dfn[B]+size[B],n);
printf("%lld
",ans);
}
}
}
}