【NOIP2019模拟赛】test
题目描述
分析
对于操作2-4,显然是树链剖分的裸题,重点是操作1
首先,操作1显然只对操作3产生影响,假设当前根为root,操作3的节点为u
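如果 $u=root$，显然答案为整棵树的权值和。
如果 $u\neq root$ 且 $lca(u,root)\neq u$（即 $root$ 不在原树中 $u$ 的子树内），那么换根操作不影响 $u$ 的子树，答案就是原树中 $u$ 的子树权值和。
如果 $u\neq root$ 且 $lca(u,root)=u$（即 $u$ 是 $root$ 的祖先），那么答案为整棵树的权值和，减去 $u$ 到 $root$ 的路径上紧挨着 $u$ 的那个点（$u$ 朝 $root$ 方向的儿子）在原树中的子树权值和。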
代码
#include<bits/stdc++.h>
using namespace std;
// Capacity for vertex-indexed arrays (assumes vertices are numbered 1..n with n <= 1e5).
const int N=1e5+5;

// Segment-tree node: covers DFS-order positions [l, r] and stores their weight sum.
struct Node
{
    int l;
    int r;
    int val;
} tree[4*N];

// Adjacency-list edge: `ver` is the endpoint, `poi` is the index of the next edge
// leaving the same vertex (0 terminates the list).
struct node
{
    int ver;
    int poi;
} edge[2*N];

int n;        // number of vertices
int m;        // number of queries
int a[N];     // weight of each vertex as read from input
int deep[N];  // depth of each vertex in the tree rooted at 1
int fa[N];    // parent in the tree rooted at 1 (0 for vertex 1)
int siz[N];   // subtree size in the tree rooted at 1
int maxson;   // global scratch; not safe to rely on across recursive calls
int son[N];   // heavy child (0 for a leaf)
int id[N];    // DFS-order position of each vertex
int val[N];   // vertex weights laid out by DFS-order position
int cnt;      // DFS-order counter
int top[N];   // head of the heavy chain containing each vertex
int root;     // current root set by operation 1 (0 until the first such query)
int head[N];  // first edge index for each vertex (0 = none)
int tot;      // number of directed edges stored so far
// Builds segment-tree node `p` over DFS-order positions [L, R].
// Leaves take their weight from val[]; internal nodes store the sum of both children.
void build(int p,int L,int R)
{
    tree[p].l=L;
    tree[p].r=R;
    if( L!=R )
    {
        const int mid=(L+R)/2;
        const int lc=2*p;
        const int rc=2*p+1;
        build(lc,L,mid);
        build(rc,mid+1,R);
        tree[p].val=tree[lc].val+tree[rc].val;
    }
    else
    {
        tree[p].val=val[L];
    }
}
// Point assignment: sets the leaf for DFS-order position x (inside node p's range)
// to `val`, then recomputes the sums of every node on the way back up to p.
void change(int p,int x,int val)
{
    const int start=p;
    // Descend to the leaf owning position x.
    while( tree[p].l!=tree[p].r )
    {
        const int mid=(tree[p].l+tree[p].r)/2;
        p = (x<=mid) ? 2*p : 2*p+1;
    }
    tree[p].val=val;
    // In heap numbering the parent of p is p/2, so this climbs exactly back to `start`.
    while( p!=start )
    {
        p/=2;
        tree[p].val=tree[2*p].val+tree[2*p+1].val;
    }
}
// Range-sum query: returns the sum of positions in [L, R] that lie inside node p's range.
int ask(int p,int L,int R)
{
    // Node range fully inside the query: its stored sum is the answer.
    if( L<=tree[p].l && tree[p].r<=R ) return tree[p].val;
    const int mid=(tree[p].l+tree[p].r)/2;
    const int fromLeft  = (L<=mid) ? ask(2*p,L,R)   : 0;
    const int fromRight = (mid<R)  ? ask(2*p+1,L,R) : 0;
    return fromLeft+fromRight;
}
// Inserts the directed edge x -> y at the front of x's adjacency list.
void add(int x,int y)
{
    tot++;
    edge[tot].ver=y;
    edge[tot].poi=head[x];
    head[x]=tot;
}
// Reads n and m, then n-1 undirected edges (each stored in both directions),
// then the n vertex weights into a[1..n].
void Input()
{
    scanf("%d%d",&n,&m);
    for(int remaining=n-1;remaining>0;remaining--)
    {
        int u,v;
        scanf("%d%d",&u,&v);
        add(u,v);
        add(v,u);
    }
    for(int k=1;k<=n;++k) scanf("%d",a+k);
}
// First heavy-light pass over the subtree of x.
// Records deep[], fa[] and siz[] for every vertex, and sets son[x] to the child
// with the largest subtree (son[x] stays 0 for a leaf).
// @param x   current vertex
// @param f   parent of x (0 for the root)
// @param dep depth assigned to x
// Recursion depth equals the height of the tree.
void dfs1(int x,int f,int dep)
{
    deep[x]=dep;
    fa[x]=f;
    siz[x]=1;
    // Must be local, not the file-scope `maxson`. The recursive call below resets and
    // overwrites a shared variable, so the comparison after it would use the child's
    // value. son[x] would then always become the last child visited instead of the
    // heaviest one, which destroys HLD's O(log n) chains-per-path bound.
    int maxson=-1;
    for(int i=head[x];i;i=edge[i].poi)
    {
        int y=edge[i].ver;
        if( y==fa[x] ) continue;
        dfs1(y,x,dep+1);
        siz[x]+=siz[y];
        if( siz[y] > maxson ) son[x]=y,maxson=siz[y];
    }
    return;
}
// Second heavy-light pass. It assigns DFS-order positions with the heavy child first,
// so each heavy chain and each subtree occupies a contiguous range of positions.
// It also copies weights into val[] in that order and records each vertex's chain head.
// @param x    current vertex
// @param topf head of the heavy chain that x belongs to
void dfs2(int x,int topf)
{
    top[x]=topf;
    id[x]=++cnt;
    val[id[x]]=a[x];
    const int heavy=son[x];
    if( !heavy ) return;  // leaf
    dfs2(heavy,topf);     // the heavy child continues the current chain
    for(int e=head[x];e!=0;e=edge[e].poi)
    {
        const int to=edge[e].ver;
        if( to!=heavy && to!=fa[x] ) dfs2(to,to);  // each light child starts a new chain
    }
}
// Roots the tree at vertex 1, runs both heavy-light passes, and builds the
// segment tree over DFS-order positions 1..n.
void Prepare()
{
    const int start=1;
    dfs1(start,0,1);
    dfs2(start,start);
    build(1,1,n);
}
// Operation 1: read a new root. Nothing is rebuilt; subtree queries in deal3 read `root`.
void deal1() { scanf("%d",&root); }
void deal2()
{
int x,y;
scanf("%d%d",&x,&y);
change(1,id[x],y);
return;
}
int check(int x,int y)
{
if( x==y ) return 0;
if( deep[x] >= deep[y] ) return -1;
while( deep[x] < deep[y] )
{
if( fa[y]==x ) return y;
y=fa[y];
}
return -1;
}
// Operation 3: print the weight sum of x's subtree when the tree is rooted at `root`.
//  - x == root: the whole tree.
//  - x is not an ancestor of root: x's subtree in the tree rooted at 1 is unchanged.
//  - x is a proper ancestor of root: the whole tree minus the subtree (in the tree
//    rooted at 1) of x's child c on the path to root.
void deal3()
{
    int x,tem=0;
    scanf("%d",&x);
    // Evaluate once; the old code repeated this tree walk up to five times.
    const int c=check(x,root);
    if( c==0 ) tem=ask(1,1,n);
    else if( c==-1 ) tem=ask(1,id[x],id[x]+siz[x]-1);
    else tem=ask(1,1,n)-ask(1,id[c],id[c]+siz[c]-1);
    // The format string was split by a raw newline (ill-formed literal); use "\n".
    printf("%d\n",tem);
    return;
}
// Sum of vertex weights on the path between x and y.
// While the endpoints are on different heavy chains, the one whose chain head is
// deeper adds its chain segment and jumps above the head. The remaining contiguous
// segment on the shared chain is added last.
int chain(int x,int y)
{
    int sum=0;
    for( ; top[x]!=top[y]; x=fa[top[x]] )
    {
        if( deep[top[x]]<deep[top[y]] ) swap(x,y);
        sum+=ask(1,id[top[x]],id[x]);
    }
    const int upper = (deep[x]<deep[y]) ? x : y;
    const int lower = (deep[x]<deep[y]) ? y : x;
    sum+=ask(1,id[upper],id[lower]);
    return sum;
}
// Operation 4: print the weight sum on the path between x and y.
void deal4()
{
    int x,y,tem;
    scanf("%d%d",&x,&y);
    tem=chain(x,y);
    // The format string was split by a raw newline (ill-formed literal); use "\n".
    printf("%d\n",tem);
    return;
}
// Reads m queries and dispatches each one by type.
// Any type other than 1, 2 or 3 is handled as operation 4.
void work()
{
    for(int left=m;left>0;left--)
    {
        int que;
        scanf("%d",&que);
        switch( que )
        {
            case 1:  deal1(); break;
            case 2:  deal2(); break;
            case 3:  deal3(); break;
            default: deal4(); break;
        }
    }
}
// Entry point: read the tree, run the heavy-light preprocessing, then answer the queries.
// Falling off the end of main returns 0.
int main()
{
    Input();
    Prepare();
    work();
}