[POI2014]HOT-Hotels 加强版 题解
题目描述:
给出一棵有 n 个点的树,求有多少组点 (i,j,k) 满足 i,j,k 两两之间的距离都相等。
Solution
很明显是一个树上计数类的问题,我们可以考虑树形DP。这道题状态的选取我认为并不太好想,但是也见过几次类似的状态可能是我太菜了想了很久......
我们首先应该会想到设\(f_{u,j}\)为在\(u\)的子树中满足\(dis(x,u)=j\)的\(x\)的个数。考虑这个转移方程如何转移,以及如何计算答案。
似乎并不好直接转移......也并不好计算答案......
为帮助转移,考虑多记录一些条件。再设\(g_{u,j}\)表示满足在\(u\)的子树中满足\(dis(lca(x,y),x)=dis(lca(x,y),y)=dis(lca(x,y),u)+j\)的\((x,y)\)点对的个数。感觉这个方程能够很好的与题目要求建立联系,尝试去转移。
对于\(f_{u,j}\),显然很好直接转移:\(f_{u,j}=\sum\limits_{v\in son} f_{v,j-1}\)。
对于\(g_{u,j}\),考虑转移方程。
首先,可以想到:\(g_{u,j}=\sum\limits_{v\in son} f_{v,j+1}\)。恰好和\(f_{u,j}\)相反。
其次,对于两个点,我们可以用这两个点转移到\(g_{i,j}\),即:\(g_{u,j}=\sum\limits_{x,y\in son}f_{x,j-1}\times f_{y,j-1}\)
那么,剩下的便是统计答案,答案显然首先加上\(g_{u,0}\)。再考虑别的情况,我们显然可以在\(g_{u,j}\)上加一条链,即\(f_{v,j-1}\),使其构成一个合法的答案,于是,我们又会有这一个式子:\(ans=\sum \limits _{v \in son}g_{u,j} \times f_{v,j-1}\)。
那么,这个DP的复杂度显然是\(O(n^2)\)的,可以通过这一道题[P3565 [POI2014]HOT-Hotels](P3565 [POI2014]HOT-Hotels)。但是仍然不足以通过这一道。
下面介绍一种优化方式——长链剖分!
其预处理和重链剖分差不多,只是把重儿子改为长儿子(子树深度最大的儿子),它可以优化这种有一维和深度相关的DP。
其基本思路是这样的:对于长儿子,直接指针继承,对于其他儿子则暴力合并,这样总复杂度是长链总长级别的,即\(O(n)\)。
Code
#include<bits/stdc++.h>
#define LL long long
using namespace std;
int n,cnt;
int dep[200005],son[200005];
int head[200005],to[400005],Next[400005];
LL ans,p[400005],*f[200005],*g[200005],*r=p;
inline int read(){
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){
if(ch=='-')f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9'){
x=(x<<1)+(x<<3)+ch-'0';
ch=getchar();
}
return x*f;
}
inline void add(int u,int v){
to[++cnt]=v;Next[cnt]=head[u];head[u]=cnt;
}
void dfs(int u,int fa){
dep[u]=1;
for(register int i=head[u];i;i=Next[i]){
int v=to[i];
if(v!=fa){
dfs(v,u);
dep[u]=max(dep[u],dep[v]+1);
son[u]=dep[v]>dep[son[u]]?v:son[u];
}
}
return;
}
void DP(int u,int fa){
if(son[u]){
f[son[u]]=f[u]+1;
g[son[u]]=g[u]-1;
DP(son[u],u);
}
f[u][0]=1;ans+=g[u][0];
for(register int i=head[u];i;i=Next[i]){
int v=to[i];
if(v==fa||v==son[u])
continue;
f[v]=r;r+=dep[v]<<1;
g[v]=r;r+=dep[v]<<1;
DP(v,u);
for(register int j=0;j<dep[v];++j){
ans+=g[u][j+1]*f[v][j];
if(j)
ans+=f[u][j-1]*g[v][j];
}
for(register int j=0;j<dep[v];++j){
g[u][j+1]+=f[u][j+1]*f[v][j];
if(j)
g[u][j-1]+=g[v][j];
f[u][j+1]+=f[v][j];
}
}
return;
}
int main(){
n=read();
for(register int i=2;i<=n;++i){
int u,v;
u=read();v=read();
add(u,v);add(v,u);
}
dfs(1,0);
f[1]=r;r+=dep[1]<<1;
g[1]=r;r+=dep[1]<<1;
DP(1,0);
printf("%lld\n",ans);
return 0;
}