Hotels
有一个树形结构,每条边的长度相同,任意两个节点可以相互到达。选3个点。两两距离相等。有多少种方案?
1≤n≤5 000
分析
参照小塘空明的题解。
很明显到一个点距离相等的三个点两两之间距离相等。
所以我们枚举该点,对子树进行暴力统计,注意统计的顺序
时间复杂度(O(n^2))
co ll size=5e3+1;
ll n,ans,tot,mx,d[size],tmp[size],s1[size],s2[size];
ll head[size],ver[size*2],next[size*2];
void add(ll x,ll y){
ver[++tot]=y,next[tot]=head[x],head[x]=tot;
}
void dfs(ll x,ll fa){
mx=std::max(mx,d[x]);
tmp[d[x]]++;
for(ll i=head[x];i;i=next[i]){
ll y=ver[i];
if(y==fa) continue;
d[y]=d[x]+1,dfs(y,x);
}
}
int main(){
read(n);
for(ll i=1;i<n;++i){
ll x=read<ll>(),y=read<ll>();
add(x,y),add(y,x);
}
for(ll x=1;x<=n;++x){
memset(s1,0,sizeof s1);
memset(s2,0,sizeof s2);
for(ll i=head[x];i;i=next[i]){
ll y=ver[i];
mx=0,d[y]=1,dfs(y,x);
for(ll j=1;j<=mx;++j){
ans+=s2[j]*tmp[j];
s2[j]+=s1[j]*tmp[j];
s1[j]+=tmp[j];
}
for(ll j=1;j<=mx;++j) tmp[j]=0;
}
}
printf("%lld
",ans);
return 0;
}
Hotel加强版
有一个树形结构,每条边的长度相同,任意两个节点可以相互到达。选3个点。两两距离相等。有多少种方案?
数据范围:n<=100000
yyb的题解
我们先考虑一个(O(n^2))的dp,也就是原题的做法。
我们考虑一下,三个点两两的距离相同是什么情况,
-
存在一个三个点公共的LCA,所以我们在LCA统计答案即可。
-
存在一个点,使得这个点到另外两个子树中距离它为d的点以及这个点的d次祖先。
所以,设计DP状态为
-
(f[i][j])表示以(i)为根的子树中,距离当前点为(j)的点数。
-
(g[i][j])表示以(i)为根的子树中,两个点到LCA的距离为(d),并且他们的LCA到(i)的距离为(d−j)的点对数,简单来说就是(i)往其他地方走(j)步就能找到一组解。
考虑合并的时候的转移:
转移的正确性比较显然,不在多讲了,并不是这里的重点。这样子的复杂度是(O(n^2))的。
我们观察一下转移的时候有这样两步:
如果我们钦定一个儿子的话,那么这个数组是可以直接赋值的,并不需要再重复计算。
所以我们用指针来写,也就是:(f[i]=f[son]−1,g[i]=g[son]+1)。
如果整棵树是链我们发现复杂度可以做到O(n),既然如此,我们推广到树。我们进行长链剖分,每次钦定从重儿子直接转移,那么我们还需要从轻儿子进行转移。不难证明所有轻儿子都是一条重链的顶部,转移时的复杂度是重链长度。
那么,复杂度拆分成两个部分:直接从重儿子转移(O(1)),从轻儿子转移(O(∑len))。发现每个点有且仅有一个父亲,因此一条重链算且仅被一个点暴力转移,而每次转移复杂度是链长。所以全局复杂度是∑链长,也就是(O(n)),因此总复杂度就是(O(n))。
这样子写下来,发现长链剖分之后,我们的复杂度变为了线性。但是注意到复杂度证明中的一点:转移和链长相关。而链长和什么相关呢?深度。所以说对于这一类与深度相关的、可以快速合并的信息,使用长链剖分可以优化到一个非常完美的复杂度。如果需要维护的与深度无关的信息的话,或许dsu on tree是一个更好的选择。
代码
DP是的for是在用相对深度,比较简单的实现方法是之前统计重儿子的时候用高度代替深度。
然后tmp必须开到4倍是因为g数组指针给儿子的时候在前移。
co int N=1e5+1;
int n,head[N],to[N*2],nx[N*2],tot;
void add(int x,int y){to[++tot]=y,nx[tot]=head[x],head[x]=tot;}
int dep[N],md[N],son[N];
void dfs1(int x,int fa){
for(int i=head[x];i;i=nx[i]){
int y=to[i];if(y==fa) continue;
dfs1(y,x),md[x]=std::max(md[x],md[y]);
if(md[y]>md[son[x]]) son[x]=y;
}
md[x]=md[son[x]]+1;
}
ll*f[N],*g[N],tmp[N*4],*id=tmp,ans;
void dfs2(int x,int fa){
if(son[x]) f[son[x]]=f[x]+1,g[son[x]]=g[x]-1,dfs2(son[x],x);
f[x][0]=1,ans+=g[x][0];
for(int i=head[x];i;i=nx[i]){
int y=to[i]; if(y==fa||y==son[x]) continue;
f[y]=id,id+=md[y]*2,g[y]=id,id+=md[y]*2;
dfs2(y,x);
for(int j=0;j<md[y];++j){
if(j)ans+=f[x][j-1]*g[y][j];
ans+=g[x][j+1]*f[y][j];
}
for(int j=0;j<md[y];++j){
g[x][j+1]+=f[x][j+1]*f[y][j];
if(j)g[x][j-1]+=g[y][j];
f[x][j+1]+=f[y][j];
}
}
}
int main(){
read(n);
for(int i=1,x,y;i<n;++i){
read(x),read(y);
add(x,y),add(y,x);
}
dfs1(1,0);
f[1]=id,id+=md[1]*2,g[1]=id,id+=md[1]*2;
dfs2(1,0);
printf("%lld
",ans);
return 0;
}