非常好的一道题.
假设当前要求 $ans[x]$.
先令 $x$ 为根,然后发现对于子树 $y$ 来说,令 $g[y]$ 表示距离 $y$ 最近的叶子节点.
若 $g[y] leqslant dis(x,y) $ 则 $y$ 子树的叶子中选一个就可以防止 $x$ 走到 $y$ 的子树中.
那么这个时候的答案就是 $sum [g[i] leqslant dis(x,i)][g[fa_{i}] > dis(x,fa_{i}]$.
不难发现如果 $x$ 满足条件,则 $x$ 的所有儿子也都满足条件,即满足条件的一定是一整颗子树.
由于我们只想让一颗子树贡献一次,我们就可以考虑使用 prufer 序列.
prufer 序列有:$sum deg[x]-1=n-2$,则有 $2=sum 2-deg[x]$,而由于算根节点的度数的时候没有减1,所以有 $1=sum 2-deg[x]$.
所以对于 $x$ 为根的所有 $y$ 来说,如果 $y$ 满足条件,$y$ 的贡献就是 $2-deg[y]$.
我们就将问题转化为 $ans[x]=sum [g[i] leqslant dis(i,x)](2-deg[i])$,这个用点分治+树状数组统计就行了.
code:
#include <cstdio> #include <cstring> #include <algorithm> #define ll long long #define N 70009 #define setIO(s) freopen(s".in","r",stdin) using namespace std; int n,edges,sn,root; int hd[N],to[N<<1],nex[N<<1]; int deg[N],g[N],dep[N]; int size[N],mx[N],vis[N],ans[N]; void add(int u,int v) { nex[++edges]=hd[u],hd[u]=edges,to[edges]=v; } struct BIT { #define M 150000 int C[M]; int lowbit(int x) { return x&(-x); } void update(int x,int v) { for(int i=x;i<M;i+=lowbit(i)) C[i]+=v; } int query(int x) { int re=0; for(int i=x;i;i-=lowbit(i)) re+=C[i]; return re; } void clr(int x) { for(int i=x;i<N;i+=lowbit(i)) C[i]=0; } #undef M }ope; void dfs1(int x,int ff) { g[x]=N; for(int i=hd[x];i;i=nex[i]) if(to[i]!=ff) { dfs1(to[i],x); g[x]=min(g[x],g[to[i]]+1); } if(deg[x]==1) g[x]=0; } void dfs2(int x,int ff) { if(ff) g[x]=min(g[x],g[ff]+1); for(int i=hd[x];i;i=nex[i]) if(to[i]!=ff) dfs2(to[i],x); } void getroot(int x,int ff) { size[x]=1,mx[x]=0; for(int i=hd[x];i;i=nex[i]) { int y=to[i]; if(vis[y]||y==ff) continue; getroot(y,x); size[x]+=size[y]; mx[x]=max(mx[x],size[y]); } mx[x]=max(mx[x],sn-size[x]); if(mx[x]<mx[root]) root=x; } void ins(int x,int ff,int t) { dep[x]=dep[ff]+1; ope.update(N+g[x]-dep[x],t*(2-deg[x])); for(int i=hd[x];i;i=nex[i]) { int y=to[i]; if(y==ff||vis[y]) continue; ins(y,x,t); } } void upd(int x,int ff) { dep[x]=dep[ff]+1; ans[x]+=ope.query(N+dep[x]); for(int i=hd[x];i;i=nex[i]) { int y=to[i]; if(y==ff||vis[y]) continue; upd(y,x); } } void DEL(int x,int ff) { ope.clr(N+g[x]-dep[x]); for(int i=hd[x];i;i=nex[i]) { if(to[i]!=ff&&!vis[to[i]]) DEL(to[i],x); } } void calc(int x) { dep[0]=-1,ins(x,0,1); ans[x]+=ope.query(N); for(int i=hd[x];i;i=nex[i]) { int y=to[i]; if(vis[y]) continue; dep[x]=0; ins(y,x,-1); upd(y,x); ins(y,x, 1); } dep[0]=-1,ins(x,0,-1); } void solve(int x) { calc(x); vis[x]=1; for(int i=hd[x];i;i=nex[i]) { int y=to[i]; if(vis[y]) continue; root=0; sn=size[y]; getroot(y,x); solve(root); } } int main() { // setIO("input"); scanf("%d",&n); int x,y,z; for(int i=1;i<n;++i) { scanf("%d%d",&x,&y); add(x,y); add(y,x); ++deg[x],++deg[y]; } dfs1(1,0); dfs2(1,0); sn=n,root=0,mx[root]=N; getroot(1,0); solve(root); for(int i=1;i<=n;++i) if(deg[i]==1) printf("1 "); else printf("%d ",ans[i]); return 0; }