考虑分治,分别求出左侧后缀和右侧前缀的直径,即需将两者两两合并:
将直径以长度和中心点(将边拆点,使长度为偶数)的方式描述,分别记为$d$和$u$
此时,对于$(d_{1},u_{1})$和$(d_{2},u_{2})$,合并后的直径长度即$\max\{d_{1},d_{2},\frac{d_{1}+d_{2}}{2}+dis(u_{1},u_{2})\}$
若直径两端点均出自某侧内部,显然即$\max\{d_{1},d_{2}\}$
若直径两端点分别出自两侧,显然$dis(x,u)\le \frac{d}{2}$,进而
$$
dis(x,y)\le dis(x,u_{1})+dis(u_{1},u_{2})+dis(u_{2},y)\le \frac{d_{1}+d_{2}}{2}+dis(u_{1},u_{2})
$$
同时,每条直径的两端点中总有一个(相对中心点)与另一条直径不在同侧,即可取到等号
在此基础上,当确定一侧后,直径单调"偏移",即$d_{1}\rightarrow \frac{d_{1}+d_{2}}{2}+dis(u_{1},u_{2})\rightarrow d_{2}$
用双指针维护三者分界点,关于$d_{1},d_{2}$的项均易处理,下面考虑$dis(u_{1},u_{2})$
换言之,即维护一个集合$S$,支持加减元素和查询$\sum_{y\in S}d(x,y)$,进而用点分树即可
时间复杂度为$o(n\log^{2}n)$,可以通过
1 #include<bits/stdc++.h> 2 using namespace std; 3 #define N 200005 4 #define ll long long 5 int n,x,y,dep[N],d[N],u[N],fa[N][20]; 6 ll ans,sum[N];vector<int>e[N]; 7 void dfs(int k,int f,int s){ 8 dep[k]=s,fa[k][0]=f; 9 for(int i=1;i<20;i++)fa[k][i]=fa[fa[k][i-1]][i-1]; 10 for(int i:e[k]) 11 if (i!=f)dfs(i,k,s+1); 12 } 13 int lca(int x,int y){ 14 if (dep[x]<dep[y])swap(x,y); 15 for(int i=19;i>=0;i--) 16 if (dep[fa[x][i]]>=dep[y])x=fa[x][i]; 17 if (x==y)return x; 18 for(int i=19;i>=0;i--) 19 if (fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i]; 20 return fa[x][0]; 21 } 22 int dis(int x,int y){ 23 return dep[x]+dep[y]-(dep[lca(x,y)]<<1); 24 } 25 int get_mid(int x,int y,int z){ 26 if (dep[x]<dep[y])swap(x,y); 27 for(int i=0;i<20;i++) 28 if ((z>>i)&1)x=fa[x][i]; 29 return x; 30 } 31 namespace DIVIDE{ 32 int rt,vis[N],sz[N],dep[N],fa[N],cnt1[N],cnt2[N],d[N][20]; 33 ll sum1[N],sum2[N];vector<int>v; 34 void get_sz(int k,int fa){ 35 sz[k]=1; 36 for(int i:e[k]) 37 if ((!vis[i])&&(i!=fa))get_sz(i,k),sz[k]+=sz[i]; 38 } 39 void get_rt(int k,int fa,int s){ 40 int mx=s-sz[k]; 41 for(int i:e[k]) 42 if ((!vis[i])&&(i!=fa))get_rt(i,k,s),mx=max(mx,sz[i]); 43 if (mx<=(s>>1))rt=k; 44 } 45 void get_dis(int k,int fa,int s){ 46 d[k][dep[s]]=dis(k,s); 47 for(int i:e[k]) 48 if ((!vis[i])&&(i!=fa))get_dis(i,k,s); 49 } 50 int dfs(int k,int s){ 51 get_sz(k,0),get_rt(k,0,sz[k]); 52 k=rt,vis[k]=1,dep[k]=s,get_dis(k,0,k); 53 for(int i:e[k]) 54 if (!vis[i])fa[dfs(i,s+1)]=k; 55 return k; 56 } 57 void add(int k,int p){ 58 for(int i=k;i;i=fa[i]){ 59 cnt1[i]+=p,sum1[i]+=p*d[k][dep[i]]; 60 if (fa[i])cnt2[i]+=p,sum2[i]+=p*d[k][dep[fa[i]]]; 61 } 62 } 63 ll query(int k){ 64 ll ans=0; 65 for(int i=k;i;i=fa[i]){ 66 ans+=(ll)cnt1[i]*d[k][dep[i]]+sum1[i]; 67 if (fa[i])ans-=(ll)cnt2[i]*d[k][dep[fa[i]]]+sum2[i]; 68 } 69 return ans; 70 } 71 }; 72 void solve(int l,int r){ 73 if (l==r)return; 74 int mid=(l+r>>1); 75 x=y=u[mid]=mid,d[mid]=0; 76 for(int i=mid-1;i>=l;i--){ 77 int dx=dis(i,x),dy=dis(i,y); 78 if (max(dx,dy)<d[i+1])d[i]=d[i+1]; 79 else{ 80 d[i]=max(dx,dy); 81 if (dx<dy)x=i;else y=i; 82 } 83 u[i]=get_mid(x,y,(d[i]>>1)); 84 } 85 x=y=u[mid+1]=mid+1,d[mid+1]=0; 86 for(int i=mid+2;i<=r;i++){ 87 int dx=dis(i,x),dy=dis(i,y); 88 if (max(dx,dy)<d[i-1])d[i]=d[i-1]; 89 else{ 90 d[i]=max(dx,dy); 91 if (dx<dy)x=i;else y=i; 92 } 93 u[i]=get_mid(x,y,(d[i]>>1)); 94 } 95 x=y=mid,sum[l-1]=0; 96 for(int i=l;i<=mid;i++)sum[i]=sum[i-1]+d[i]; 97 for(int i=mid+1;i<=r;i++){ 98 while ((l<=x)&&(d[x]<(d[i]+d[x]>>1)+dis(u[i],u[x])))DIVIDE::add(u[x--],1); 99 while ((x<y)&&((d[i]+d[y]>>1)+dis(u[i],u[y])<d[i]))DIVIDE::add(u[y--],-1); 100 ans+=(sum[x]+sum[y]>>1)+(ll)(y-x)*(d[i]>>1)+DIVIDE::query(u[i])+(ll)(mid-y)*d[i]; 101 } 102 while (x<y)DIVIDE::add(u[y--],-1); 103 solve(l,mid),solve(mid+1,r); 104 } 105 int main(){ 106 scanf("%d",&n); 107 for(int i=1;i<n;i++){ 108 scanf("%d%d",&x,&y); 109 e[x].push_back(i+n),e[i+n].push_back(x); 110 e[y].push_back(i+n),e[i+n].push_back(y); 111 } 112 dfs(1,0,1),DIVIDE::dfs(1,1); 113 solve(1,n),printf("%lld\n",(ans>>1)); 114 return 0; 115 }