题目好神仙……这个叫长链剖分的玩意儿更神仙……
考虑dp,设(f[i][j])表示以(i)为根的子树中到(i)的距离为(j)的点的个数,(g[i][j])表示(i)的子树中有(g[i][j])对点深度相同,他们到LCA的距离为(d),且他们的LCA到(i)的距离为(d-j)。或者换句话来说就是以(i)为根的子树中有这么多个点对,而且没有第三个点去和这些点对匹配,第三个点不在(i)的子树中且到(i)的距离为(j),(g[i][j])表示这些点对的个数
设(u)为当前点,(v)为某一子树,那么转移方程如下
[f[u][i]+=f[v][i+1]
]
[g[u][i-1]+=g[v][i]
]
[g[u][i+1]+=f[u][i+1]*f[v][i]
]
[ans+=f[u][i-1]*g[v][i]+g[u][i+1]*f[v][i]
]
如果是原题的(nleq 5000)已经足够了,然而当(nleq 100000)的时候很明显gg了
发现状态数组的第二维实际上跟这个节点的深度有关,于是考虑用长链剖分优化。(不知道什么是长链剖分的可以看看蒟蒻的笔记)简单来说记每一个节点深度最大的儿子为它的重儿子。因为第一次转移的时候有(f[u][i]=f[v][i-1],g[u][i]=g[v][i+1]),于是可以类似于dsu on tree的思想,对于每个重儿子的信息直接继承,轻儿子暴力跑一遍。重儿子的信息可以直接用指针来达到(O(1))的转移
这个时间复杂度大概是(O(n))的,对于每个点转移的复杂度为(sum dep[v]-dep[son[u]]=sum dep[v]-dep[u]+1),然后所有点的加起来除了叶子结点都互相抵消,于是总的复杂度为(O(n))
空间复杂度也是(O(n)),因为非叶节点的空间都是由它所在重链的儿子转移来的,所以对每个叶节点开正比于此重链长度的空间即可
//minamoto
#include<bits/stdc++.h>
#define ll long long
using namespace std;
#define getc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
char buf[1<<21],*p1=buf,*p2=buf;
int read(){
int res,f=1;char ch;
while((ch=getc())>'9'||ch<'0')(ch=='-')&&(f=-1);
for(res=ch-'0';(ch=getc())>='0'&&ch<='9';res=res*10+ch-'0');
return res*f;
}
const int N=1e5+5,M=1005;
int head[N],Next[N<<1],ver[N<<1],tot;
inline void add(int u,int v){ver[++tot]=v,Next[tot]=head[u],head[u]=tot;}
ll memp[N*5],*f[N],*g[N],*to=memp+5,ans;
int n,dep[N],mx[N];
void dfs(int u,int fa){
mx[u]=u;
for(int i=head[u];i;i=Next[i]){
int v=ver[i];
if(v!=fa){
dep[v]=dep[u]+1,dfs(v,u);
if(dep[mx[v]]>dep[mx[u]])mx[u]=mx[v];
}
}
for(int i=head[u];i;i=Next[i]){
int v=ver[i];
if(v!=fa&&(mx[v]!=mx[u]||u==1)){
v=mx[v],to+=dep[v]-dep[u]+1;
f[v]=to,g[v]=(to+=1),to+=(dep[v]-dep[u])*2+1;
}
}
}
void dp(int u,int fa){
for(int i=head[u];i;i=Next[i]){
int v=ver[i];if(v==fa)continue;dp(v,u);
if(mx[v]==mx[u])f[u]=f[v]-1,g[u]=g[v]+1;
}
ans+=g[u][0],f[u][0]=1;
for(int i=head[u];i;i=Next[i]){
int v=ver[i];if(v==fa||mx[v]==mx[u])continue;
for(int j=0;j<=dep[mx[v]]-dep[u];++j)
ans+=f[u][j-1]*g[v][j]+g[u][j+1]*f[v][j];
for(int j=0;j<=dep[mx[v]]-dep[u];++j){
g[u][j-1]+=g[v][j];
g[u][j+1]+=f[u][j+1]*f[v][j];
f[u][j+1]+=f[v][j];
}
}
}
int main(){
// freopen("testdata.in","r",stdin);
n=read();
for(int i=1,u,v;i<n;++i)u=read(),v=read(),add(u,v),add(v,u);
while(to!=memp)*to=0,--to;*to=0,++to;
dep[1]=1;dfs(1,0),dp(1,0);
printf("%lld
",ans);return 0;
}