根据期望的线性性,我们算出每个点期望被计算次数,然后进行累加.
考虑点 $x$ 对点 $y$ 产生了贡献,那么说明 $(x,y)$ 之间的点中 $x$ 是第一个被删除的.
这个期望就是 $frac{1}{dis(x,y)+1}$,所以我们只需求 $sum_{i=1}^{n}sum_{j=1}^{n}frac{1}{dis(i,j)+1}$ 即可.
然后这个直接求是求不出来的,所以需要用点分治+FFT来算树上每种距离都出现了多少次.
code:
#include <bits/stdc++.h> using namespace std; #define N 500003 #define ll long long #define setIO(s) freopen(s".in","r",stdin) const double pi=acos(-1); ll ans[N]; int edges,root,sn,n,mxdep; int size[N],mx[N],hd[N],to[N<<1],nex[N<<1],vis[N]; inline void add(int u,int v) { nex[++edges]=hd[u],hd[u]=edges,to[edges]=v; } struct cpx { double x,y; cpx(double a=0,double b=0) { x=a,y=b; } cpx operator+(const cpx b) { return cpx(x+b.x,y+b.y); } cpx operator-(const cpx b) { return cpx(x-b.x,y-b.y); } cpx operator*(const cpx b) { return cpx(x*b.x-y*b.y,x*b.y+y*b.x); } }A[N],B[N]; void fft(cpx *a,int len,int flag) { int i,j,k,mid; for(i=k=0;i<len;++i) { if(i>k) swap(a[i],a[k]); for(j=len>>1;(k^=j)<j;j>>=1); } for(mid=1;mid<len;mid<<=1) { cpx wn(cos(pi/mid), flag*sin(pi/mid)),x,y; for(i=0;i<len;i+=mid<<1) { cpx w(1,0); for(j=0;j<mid;++j) { x=a[i+j],y=w*a[i+j+mid]; a[i+j]=x+y; a[i+j+mid]=x-y; w=w*wn; } } } if(flag==-1) for(int i=0;i<len;++i) a[i].x/=(double)len; } void getroot(int u,int ff) { size[u]=1,mx[u]=0; for(int i=hd[u];i;i=nex[i]) { int v=to[i]; if(v==ff||vis[v]) continue; getroot(v,u); size[u]+=size[v]; mx[u]=max(mx[u], size[v]); } mx[u]=max(mx[u], sn-size[u]); if(mx[u]<mx[root]) root=u; } void dfs(int u,int ff,int d) { ++A[d].x; mxdep=max(mxdep,d); for(int i=hd[u];i;i=nex[i]) { int v=to[i]; if(v==ff||vis[v]) continue; dfs(v,u,d+1); } } void calc(int u,int d) { mxdep=0; dfs(u,0,d==1?0:1); int len=1; while(len<=(mxdep+mxdep+2)) len<<=1; fft(A,len,1); for(int i=0;i<len;++i) A[i]=A[i]*A[i]; fft(A,len,-1); for(int i=0;i<min(len,n);++i) ans[i]+=(ll)(A[i].x+0.1)*d; for(int i=0;i<len;++i) A[i].x=A[i].y=0; } void solve(int u) { vis[u]=1; calc(u,1); for(int i=hd[u];i;i=nex[i]) { int v=to[i]; if(vis[v]) continue; calc(v,-1); } for(int i=hd[u];i;i=nex[i]) { int v=to[i]; if(vis[v]) continue; root=0,sn=size[v],getroot(v,u); solve(root); } } int main() { // setIO("input"); int i,j; scanf("%d",&n); for(i=1;i<n;++i) { int x,y; scanf("%d%d",&x,&y); ++x,++y; add(x,y),add(y,x); } mx[0]=sn=n; getroot(1,0); solve(root); double tmp=0.0; for(int i=0;i<n;++i) { tmp+=(double) ans[i]/(i+1); } printf("%.4f ",tmp); return 0; }