Link:
Solution:
感觉求LCA又有了新姿势啊:$Tarjan$离线$O(n+m)$
每次递归返回时将子树和父节点合并,如果询问节点已访问过则LCA就是已合并的最高节点
这题部分分提示非常多啊
首先要将路径拆为$(S,LCA),(LCA,T)$
发现如果$(S,LCA)$能对点$x$产生贡献要满足$w[x]+dep[x]=dep[S]$
而$(LCA,T)$能对点$x$产生贡献要满足$dep[x]-w[x]=dep[T]-len$
这样用$cnt$数组维护等式右边的$dep[S]$和$dep[T]-len$的值有多少个就能快速得出有几条路径满足条件
于是可以在路径起点加入该路径特征值并在路径末尾将其消除即可
注意:
1、$LCA$处可能算了两遍,最后要逐一判断
2、要在刚进入该点时记录当前$cnt[w[x]+dep[x]]$的值否则可能会将其它子树中未走完的路径计算在内
3、此题需要从下往上统计答案,因此路径起点都要设置为深度较大的,否则不好消除不经过该点路径的贡献
Code:
#include <bits/stdc++.h> using namespace std; #define X first #define Y second #define pb push_back typedef double db; typedef long long ll; typedef pair<int,int> P; const int MAXN=1e6+10,ADD=3e5; int vis[MAXN],f[MAXN],st[MAXN]; int n,q,x,y,w[MAXN],res[MAXN],cnt[MAXN],head[MAXN],dep[MAXN],tot; vector<P> par[MAXN]; vector<int> in[MAXN],out[MAXN]; struct edge{int nxt,to;}e[MAXN<<2]; struct Query{int x,y,lca;}qry[MAXN]; void add(int x,int y) {e[++tot]=(edge){head[x],y};head[x]=tot;} int find(int x) {return f[x]==x?x:f[x]=find(f[x]);} void tarjan(int x,int anc) { vis[x]=1;f[x]=x; for(int i=0;i<par[x].size();i++) if(vis[par[x][i].X]) qry[par[x][i].Y].lca=find(par[x][i].X); for(int i=head[x];i;i=e[i].nxt) if(e[i].to!=anc) { dep[e[i].to]=dep[x]+1; tarjan(e[i].to,x);f[e[i].to]=x; } } void dfs1(int x,int anc) { int cur=cnt[w[x]+dep[x]]; for(int i=head[x];i;i=e[i].nxt) if(e[i].to!=anc) dfs1(e[i].to,x); cnt[dep[x]]+=st[x]; res[x]+=cnt[w[x]+dep[x]]-cur; for(int i=0;i<out[x].size();i++) cnt[out[x][i]]--; } void dfs2(int x,int anc) { int cur=cnt[ADD-w[x]+dep[x]]; for(int i=head[x];i;i=e[i].nxt) if(e[i].to!=anc) dfs2(e[i].to,x); //都要看成从底向上的路径 for(int i=0;i<in[x].size();i++) cnt[in[x][i]]++; res[x]+=cnt[ADD-w[x]+dep[x]]-cur; for(int i=0;i<out[x].size();i++) cnt[out[x][i]]--; } int main() { scanf("%d%d",&n,&q); for(int i=1;i<n;i++) scanf("%d%d",&x,&y),add(x,y),add(y,x); for(int i=1;i<=n;i++) scanf("%d",&w[i]); for(int i=1;i<=q;i++) { scanf("%d%d",&x,&y); qry[i].x=x,qry[i].y=y; par[x].pb(P(y,i));par[y].pb(P(x,i)); } tarjan(1,0); for(int i=1;i<=q;i++) out[qry[i].lca].pb(dep[qry[i].x]),st[qry[i].x]++; dfs1(1,0); memset(cnt,0,sizeof(cnt)); for(int i=1;i<=n;i++) out[i].clear(); for(int i=1;i<=q;i++) { int len=dep[qry[i].x]+dep[qry[i].y]-2*dep[qry[i].lca]; in[qry[i].y].pb(ADD+dep[qry[i].y]-len); out[qry[i].lca].pb(ADD+dep[qry[i].y]-len); } dfs2(1,0); for(int i=1;i<=q;i++) if(dep[qry[i].x]-dep[qry[i].lca]==w[qry[i].lca]) res[qry[i].lca]--; for(int i=1;i<=n;i++) printf("%d ",res[i]); return 0; }