智商不够数据结构来凑
常规操作就是将树上的一条路径((s,t))拆分成((s,lca))和((lca,t))来看
首先考虑一下上行路径
显然对于点(x)来说,只有(dep[s]=dep[x]+a[x])且(lca)在(x)子树外面
好像非常难算的样子,我们考虑减掉(lca)在子树内部的情况
于是我们先不考虑(lca)在(x)子树外面这一限制条件,我们直接求出所有的的(dep[x]=dep[x]+a[x]),也就是求在(dep[x]+a[x])这一深度上有几个起点
主席树上树可以随便做
之后减掉(lca)在子树内部的情况,于是我们可以将起点的深度作为权值,减掉子树内部的所有权值为(dep[x]+a[x])的(lca)就好了
之后是下行路径
显然这个时候需要满足的条件是(L-a[x]+dep[x]-1=dep[t]),其中(L)表示路径长度
同时还得满足(lca)在子树外面
我们可以把(dep[t]-L)作为权值,之后还是子树查询就好了
所以一共需要四个主席树
同时还有几个比较坑的地方
-
主席树的值域开大一些,大概(2*n)就足够了
-
尽管我们要减掉(lca)在子树内部的情况,但是(lca)恰好在(x)时是可以的,于是还要加回来
-
lca如果算了两次那么就需要特判一下
代码
#include<iostream>
#include<cstring>
#include<cstdio>
#include<vector>
#define re register
#define maxn 300000
#define max(a,b) ((a)>(b)?(a):(b))
struct E
{
int v,nxt;
}e[maxn<<1];
int n,m,num,maxdep,tot;
int fa[maxn],top[maxn],deep[maxn],son[maxn],sum[maxn],to[maxn],_to[maxn];
int a[maxn],head[maxn],v1[maxn];
int ans[maxn];
std::vector<int> v2[maxn],v3[maxn],v4[maxn];
inline int read()
{
char c=getchar();
int x=0;
while(c<'0'||c>'9') c=getchar();
while(c>='0'&&c<='9')
x=(x<<3)+(x<<1)+c-48,c=getchar();
return x;
}
inline void add_edge(int x,int y)
{
e[++num].v=y;
e[num].nxt=head[x];
head[x]=num;
}
void dfs1(int x)
{
sum[x]=1;
int maxx=-1;
for(re int i=head[x];i;i=e[i].nxt)
if(!deep[e[i].v])
{
deep[e[i].v]=deep[x]+1;
maxdep=max(maxdep,deep[e[i].v]);
fa[e[i].v]=x;
dfs1(e[i].v);
sum[x]+=sum[e[i].v];
if(sum[e[i].v]>maxx) son[x]=e[i].v,maxx=sum[e[i].v];
}
}
void dfs2(int x,int topf)
{
top[x]=topf;
to[x]=++tot;
_to[tot]=x;
if(!son[x]) return;
dfs2(son[x],topf);
for(re int i=head[x];i;i=e[i].nxt)
if(deep[e[i].v]>deep[x]&&son[x]!=e[i].v) dfs2(e[i].v,e[i].v);
}
inline int LCA(int x,int y)
{
while(top[x]!=top[y])
{
if(deep[top[x]]<deep[top[y]]) std::swap(x,y);
x=fa[top[x]];
}
if(deep[x]>deep[y]) return y;
return x;
}
struct Segment_Tree
{
int l[maxn*29],r[maxn*29],d[maxn*29];
int rt[maxn];
int cnt;
int build(int x,int y)
{
int root=++cnt;
if(x==y) return root;
int mid=x+y>>1;
l[root]=build(x,mid),r[root]=build(mid+1,y);
return root;
}
int change(int pre,int x,int y,int pos,int val)
{
int root=++cnt;
d[root]=d[pre]+val;
if(x==y) return root;
l[root]=l[pre],r[root]=r[pre];
int mid=x+y>>1;
if(pos<=mid) l[root]=change(l[pre],x,mid,pos,val);
else r[root]=change(r[pre],mid+1,y,pos,val);
return root;
}
int query(int p,int x,int y,int pos)
{
if(x==y) return d[p];
int mid=x+y>>1;
if(pos<=mid) return query(l[p],x,mid,pos);
return query(r[p],mid+1,y,pos);
}
}up,_up,down,_down;
int main()
{
n=read(),m=read();
int x,y;
for(re int i=1;i<n;i++)
x=read(),y=read(),add_edge(x,y),add_edge(y,x);
deep[1]=1;
dfs1(1),dfs2(1,1);
for(re int i=1;i<=n;i++)
a[i]=read();
int s,t;
for(re int i=1;i<=m;i++)
{
s=read(),t=read();
int lca=LCA(s,t);
int L=deep[s]+deep[t]-2*deep[lca]+1;
if(deep[s]-deep[lca]==a[lca]) ans[lca]--;
v1[s]++;
v2[lca].push_back(deep[s]);
v3[t].push_back(deep[t]-L+n);
v4[lca].push_back(deep[t]-L+n);
}
up.rt[0]=up.build(1,n*2);
_up.rt[0]=_up.build(1,n*2);
down.rt[0]=down.build(0,n*2);
_down.rt[0]=_down.build(0,n*2);
for(re int i=1;i<=n;i++)
{
up.rt[i]=up.change(up.rt[i-1],1,n*2,deep[_to[i]],v1[_to[i]]);
int pre=_up.rt[i-1];
for(re int j=0;j<v2[_to[i]].size();j++)
pre=_up.change(pre,1,n*2,v2[_to[i]][j],1);
_up.rt[i]=pre;
pre=down.rt[i-1];
for(re int j=0;j<v3[_to[i]].size();j++)
pre=down.change(pre,0,n*2,v3[_to[i]][j],1);
down.rt[i]=pre;
pre=_down.rt[i-1];
for(re int j=0;j<v4[_to[i]].size();j++)
pre=_down.change(pre,0,n*2,v4[_to[i]][j],1);
_down.rt[i]=pre;
}
for(re int i=1;i<=n;i++)
{
int x=to[i],y=to[i]+sum[i]-1;
ans[i]+=up.query(up.rt[y],1,n*2,deep[i]+a[i])-up.query(up.rt[x-1],1,n*2,deep[i]+a[i]);
ans[i]-=_up.query(_up.rt[y],1,n*2,deep[i]+a[i])-_up.query(_up.rt[x-1],1,n*2,deep[i]+a[i]);
ans[i]+=_up.query(_up.rt[x],1,n*2,deep[i]+a[i])-_up.query(_up.rt[x-1],1,n*2,deep[i]+a[i]);
ans[i]+=down.query(down.rt[y],0,n*2,deep[i]-a[i]-1+n)-down.query(down.rt[x-1],0,n*2,deep[i]-a[i]-1+n);
ans[i]-=_down.query(_down.rt[y],0,n*2,deep[i]-a[i]-1+n)-_down.query(_down.rt[x-1],0,n*2,deep[i]-a[i]-1+n);
ans[i]+=_down.query(_down.rt[x],0,n*2,deep[i]-a[i]-1+n)-_down.query(_down.rt[x-1],0,n*2,deep[i]-a[i]-1+n);
}
for(re int i=1;i<=n;i++)
printf("%d ",ans[i]);
return 0;
}