一直以为自己当时是 TLE 了,回头再看才发现居然是 WA?
然后把数组扩大一倍,就 AC 了。QaQ
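没什么好说的。以 LCA 为界,把一条路径分成上升和下降两段考虑:上升的一段上,“深度 + 到达时间”是定值(等于起点深度);下降的一段上,“深度 − 到达时间”是定值(等于 2·dep(LCA) − dep(起点))。某点的观察员能看到玩家,当且仅当该点的观察时间等于玩家到达该点的时间,所以只需沿两段路径打标记、按上述定值分别统计即可;LCA 同时属于两段,要去掉重复计算的那一次。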
后来发现,原因大概是“深度 + 时间”可能超过 n,统计数组开得不够大,下标越界了。
现在想想,当时没有对拍,真是后怕。
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <utility>
#include <vector>

// For every node x of a tree (nodes 1..n, rooted at 1) counts the players whose
// path s -> t passes through x at exactly time w[x], assuming a player leaves s
// at time 0 and moves one edge per unit of time.
//
// Each path is split at its LCA:
//   * upward leg   s -> lca: depth[x] + time is constant (= depth[s]);
//   * downward leg lca -> t: depth[x] - time is constant (= 2*depth[lca] - depth[s]).
// Each leg adds its constant ("key") to every node on it, using difference
// events over heavy-light positions. A sweep then counts, at each node, the
// active keys equal to that node's observer value.

namespace {

// Reads one decimal integer (optionally preceded by '-') from stdin, skipping
// any other characters before it. Returns 0 if EOF is reached first.
int read() {
    int value = 0;
    int sign = 1;
    int ch = std::getchar();  // int, not char, so EOF is detectable
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') sign = -1;
        ch = std::getchar();
    }
    while (ch >= '0' && ch <= '9') {
        value = value * 10 + (ch - '0');
        ch = std::getchar();
    }
    return value * sign;
}

// Heavy-light decomposition of a tree given as adjacency lists adj[1..n]
// (adj[0] unused). Built without recursion so deep trees (e.g. long chains)
// cannot overflow the call stack.
struct HeavyLight {
    std::vector<int> parent;   // parent[root] == 0
    std::vector<int> depth;    // depth[root] == 0
    std::vector<int> heavy;    // child with the largest subtree, 0 for a leaf
    std::vector<int> top;      // first node of the heavy chain containing the node
    std::vector<int> pos;      // 1-based slot; each heavy chain is contiguous
    std::vector<int> node_at;  // inverse of pos

    HeavyLight(const std::vector<std::vector<int>>& adj, int root) {
        const int n = static_cast<int>(adj.size()) - 1;
        parent.assign(n + 1, 0);
        depth.assign(n + 1, 0);
        heavy.assign(n + 1, 0);
        top.assign(n + 1, 0);
        pos.assign(n + 1, 0);
        node_at.assign(n + 1, 0);

        // BFS fixes parent/depth and yields an order listing parents before children.
        std::vector<int> order;
        order.reserve(n);
        order.push_back(root);
        for (std::size_t i = 0; i < order.size(); ++i) {
            const int u = order[i];
            for (const int v : adj[u]) {
                if (v == parent[u]) continue;
                parent[v] = u;
                depth[v] = depth[u] + 1;
                order.push_back(v);
            }
        }

        // Reverse BFS order visits every child before its parent, so subtree
        // sizes are final when they are folded upward.
        std::vector<int> subtree(n + 1, 1);
        subtree[0] = 0;  // heavy == 0 must compare as an empty subtree
        for (std::size_t i = order.size(); i-- > 1;) {
            const int v = order[i];
            const int u = parent[v];
            subtree[u] += subtree[v];
            if (subtree[v] > subtree[heavy[u]]) heavy[u] = v;
        }

        // Explicit-stack DFS: the heavy child is pushed last, so it is popped
        // immediately after its parent and each heavy chain gets consecutive slots.
        std::vector<int> stack;
        stack.reserve(n);
        stack.push_back(root);
        top[root] = root;
        int next_pos = 0;
        while (!stack.empty()) {
            const int u = stack.back();
            stack.pop_back();
            pos[u] = ++next_pos;
            node_at[next_pos] = u;
            for (const int v : adj[u]) {
                if (v == parent[u] || v == heavy[u]) continue;
                top[v] = v;  // a light child starts a new chain
                stack.push_back(v);
            }
            if (heavy[u] != 0) {
                top[heavy[u]] = top[u];
                stack.push_back(heavy[u]);
            }
        }
    }

    // Lowest common ancestor of a and b.
    int lca(int a, int b) const {
        while (top[a] != top[b]) {
            if (depth[top[a]] < depth[top[b]]) std::swap(a, b);
            a = parent[top[a]];
        }
        return depth[a] < depth[b] ? a : b;
    }

    // Calls emit(l, r) for each position interval [l, r] (l <= r) that together
    // cover exactly the nodes on the path a .. b.
    template <class Emit>
    void for_each_segment(int a, int b, Emit emit) const {
        while (top[a] != top[b]) {
            if (depth[top[a]] < depth[top[b]]) std::swap(a, b);
            emit(pos[top[a]], pos[a]);
            a = parent[top[a]];
        }
        if (depth[a] < depth[b]) std::swap(a, b);
        emit(pos[b], pos[a]);
    }
};

// A path from `ancestor` down to `endpoint` (both inclusive), tagged with the
// key that every node on it is matched against.
struct Path {
    int ancestor;
    int endpoint;
    int key;
};

// For every node x, adds to ans[x] the number of paths that contain x and whose
// key equals want[x]. Keys must lie in [-n, n]. They are shifted by n so the
// counter array is never indexed with a negative value. A want outside [-n, n]
// cannot match any key and is skipped.
void count_matches(const HeavyLight& hl, int n, const std::vector<Path>& paths,
                   const std::vector<long long>& want, std::vector<int>& ans) {
    // Difference events over HLD positions: a segment [l, r] activates its key
    // at l and deactivates it at r + 1 (r + 1 <= n + 1, hence size n + 2).
    std::vector<std::vector<int>> opens(n + 2);
    std::vector<std::vector<int>> closes(n + 2);
    for (const Path& p : paths) {
        const int slot = p.key + n;  // [-n, n] -> [0, 2n]
        hl.for_each_segment(p.ancestor, p.endpoint, [&](int l, int r) {
            opens[l].push_back(slot);
            closes[r + 1].push_back(slot);
        });
    }

    std::vector<int> live(2 * static_cast<std::size_t>(n) + 1, 0);
    for (int i = 1; i <= n; ++i) {
        for (const int slot : opens[i]) ++live[slot];
        for (const int slot : closes[i]) --live[slot];
        const int x = hl.node_at[i];
        const long long want_slot = want[x] + n;
        if (want_slot >= 0 && want_slot <= 2LL * n) {
            ans[x] += live[static_cast<std::size_t>(want_slot)];
        }
    }
}

}  // namespace

int main() {
    const int n = read();
    const int m = read();
    if (n <= 0) return 0;

    std::vector<std::vector<int>> adj(n + 1);
    for (int i = 1; i < n; ++i) {
        const int a = read();
        const int b = read();
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    std::vector<int> w(n + 1, 0);
    for (int i = 1; i <= n; ++i) w[i] = read();

    const HeavyLight hl(adj, 1);
    std::vector<int> ans(n + 1, 0);
    std::vector<Path> up;
    std::vector<Path> down;
    if (m > 0) {
        up.reserve(static_cast<std::size_t>(m));
        down.reserve(static_cast<std::size_t>(m));
    }
    for (int i = 0; i < m; ++i) {
        const int s = read();
        const int t = read();
        const int meet = hl.lca(s, t);
        const int climb = hl.depth[s] - hl.depth[meet];  // arrival time at the LCA
        // Upward leg s -> meet: x is reached at time depth[s] - depth[x], so a
        // sighting means w[x] + depth[x] == depth[s]  (key in [0, n-1]).
        up.push_back({meet, s, hl.depth[s]});
        // Downward leg meet -> t: x is reached at time climb + depth[x] - depth[meet],
        // so a sighting means w[x] - depth[x] == climb - depth[meet]
        // (key in [-(n-1), n-1]; may be negative).
        down.push_back({meet, t, climb - hl.depth[meet]});
        // The LCA lies on both legs with the same arrival time; cancel one count.
        if (w[meet] == climb) --ans[meet];
    }

    std::vector<long long> want_up(n + 1, 0);
    std::vector<long long> want_down(n + 1, 0);
    for (int x = 1; x <= n; ++x) {
        want_up[x] = static_cast<long long>(w[x]) + hl.depth[x];
        want_down[x] = static_cast<long long>(w[x]) - hl.depth[x];
    }
    count_matches(hl, n, up, want_up, ans);
    count_matches(hl, n, down, want_down, ans);

    for (int i = 1; i < n; ++i) std::printf("%d ", ans[i]);
    std::printf("%d", ans[n]);
    return 0;
}