巨坑
树剖学的好啊!---sfailsth
把一段路径拆成两段,向上和S->LCA,向下LCA->T
用维护重链什么的操作搞一下。
sfailsth学长真不容易啊。。。考场上rush了4.58KB代码。。。。
常数巨大懒得优化,最慢一个点1528ms/128703KB,Orz
#include <iostream>
#include <cstring>
#include <vector>
#include <cstdio>
using namespace std;
const int N=600000;
int read() {
int ret=0,f=1; char ch=getchar();
while (ch<'0'||ch>'9'){if (ch=='-') f=-1;ch=getchar();}
while (ch>='0'&&ch<='9'){ret*=10;ret+=ch-'0';ch=getchar();}
return ret*f;
}
struct Edge {
int to,nxt;
}e[N];
struct Peo {
int s,t,lca;
} p[N];
int n,m,w[N],top[N],dfn1[N],son[N],siz[N],dep[N],fa[N],tim,dfn2[N],f[N],rnk[N],ans[N],s[N],head[N],ecnt;
void add(int bg,int ed){e[++ecnt].nxt=head[bg];e[ecnt].to=ed;head[bg]=ecnt;}
bool vis[N];
int find(int x) {
return x==f[x]?x:f[x]=find(f[x]);
}
vector<int>Q[N],ID[N],S[N],T[N];
void dfs1(int x) {
siz[x]=1;
for(int i=head[x]; i; i=e[i].nxt) {
int v=e[i].to;
if(v==fa[x])continue;
fa[v]=x;
dep[v]=dep[x]+1;
dfs1(v);
siz[x]+=siz[v];
if(siz[v]>siz[son[x]]) son[x]=v;
}
}
void dfs2(int x,int qtop) {
dfn1[x]=++tim;
top[x]=qtop;
rnk[tim]=x;
if(son[x]) dfs2(son[x],qtop);
for(int i=head[x]; i; i=e[i].nxt) {
int v=e[i].to;
if(v==fa[x]||v==son[x]) continue;
dfs2(v,v);
}
dfn2[x]=tim;
}
void dfs3(int x) {
f[x]=x;
vis[x]=1;
for(int i=0; i<Q[x].size(); i++)
if(vis[Q[x][i]])
p[ID[x][i]].lca=find(Q[x][i]);
for(int i=head[x]; i; i=e[i].nxt)
if(fa[x]!=e[i].to)
dfs3(e[i].to),f[e[i].to]=x;
return;
}
void add(int x,int y,int depth) {
if(dep[x]<dep[y]) swap(x,y);
while(top[x]!=top[y]) {
if(dep[top[x]]<dep[top[y]]) swap(x,y);
S[dfn1[top[x]]].push_back(depth);
T[dfn1[x]+1].push_back(depth);
x=fa[top[x]];
}
if(dep[x]<dep[y])swap(x,y);
S[dfn1[y]].push_back(depth);
T[dfn1[x]+1].push_back(depth);
}
void calc1() {
for(int i=1; i<=n; i++) w[i]-=dep[i];
for(int i=1; i<=m; i++) {
int now=dep[p[i].s]-dep[p[i].lca];
if(w[p[i].lca]==now-dep[p[i].lca]) ans[p[i].lca]--;
add(p[i].lca,p[i].t,now-dep[p[i].lca]);
}
for(int i=1; i<=n; i++) {
for(int j=0; j<S[i].size(); j++)
s[S[i][j]]++;
for(int j=0; j<T[i].size(); j++)
s[T[i][j]]--;
ans[rnk[i]]+=s[w[rnk[i]]];
}
}
void calc2() {
for(int i=0; i<=n; i++) S[i].clear(),T[i].clear();
memset(s,0,sizeof s);
for(int i=1; i<=n; i++) w[i]+=dep[i]<<1;
for(int i=1; i<=m; i++)
add(p[i].lca,p[i].s,dep[p[i].s]);
for(int i=1; i<=n; i++) {
for(int j=0; j<S[i].size(); j++) s[S[i][j]]++;
for(int j=0; j<T[i].size(); j++) s[T[i][j]]--;
ans[rnk[i]]+=s[w[rnk[i]]];
}
}
int main() {
n=read(),m=read();
for(int i=1,u,v; i<n; i++)
u=read(),v=read(),add(u,v),add(v,u);
for(int i=1; i<=n; i++)
w[i]=read();
for(int u,v,i=1; i<=m; i++) {
u=read(),v=read();
Q[u].push_back(v);
Q[v].push_back(u);
ID[u].push_back(i);
ID[v].push_back(i);
p[i].s=u;
p[i].t=v;
}
dfs1(1);
dfs2(1,1);
dfs3(1);
calc1();
calc2();
for(int i=1; i<=n; i++) printf("%d ",ans[i]);
}