题面
题解
三点距离两两相等要满足以下条件:
有一个相同的$ ext{LCA}$
所以如果存在一个点,使得另外两个点在它子树中,距离为$d$,且$ ext{LCA}$距这个点为$d$,
那么这三个点就距离两两相等。
设$f[i][j]$表示以$i$为根的子树中距$i$为$j$的点的数量
$g[i][j]$表示以$i$为根的子树中,两个点到$LCA$的距离为$d$,并且他们的$LCA$到$i$的距离为$d−j$的点对数。
看到合并时的转移:
$$ ans ext{+=}g[i][0] \ ans ext{+=}g[i][j]*f[son][j-1] \ f[i][j]=sum f[son][j-1] \ g[i][j]=sum g[son][j+1] $$
复杂度为$ ext{O}(n^2)$
而转移方程下面两个柿子
$$ f[i][j]=sum f[son][j-1] \ g[i][j]=sum g[son][j+1] $$
如果我们选择一个转移代价最高的儿子直接继承过来的话,复杂度就会大大降低
而转移代价最高的儿子显然就是长链剖分后那个点的重儿子
那么用指针描述就是:
$$ f[i]=f[heavy[i]] - 1,; g[i] = g[heavy[i]] + 1 $$
这样可以快速继承。
于是我们就将整棵树进行长链剖分,钦定从重儿子转移,其他儿子重新计算
于是从重儿子转移$O(1)$,从轻儿子转移$O($链长$)$
总时间复杂度$O(n+sum$链长$)=O(n)$
同时因为使用指针,所以常数小出现了一些玄学问题
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
#define RG register
inline int read()
{
int data = 0, w = 1;
char ch = getchar();
while(ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
if(ch == '-') w = -1, ch = getchar();
while(ch >= '0' && ch <= '9') data = data * 10 + (ch ^ 48), ch = getchar();
return data * w;
}
const int maxn(100010);
struct edge { int next, to; } e[maxn << 1];
int head[maxn], e_num, n;
inline void add_edge(int from, int to)
{
e[++e_num] = (edge) {head[from], to};
head[from] = e_num;
}
int dep[maxn], heavy[maxn], maxdep[maxn];
void dfs(int x, int fa)
{
// maxdep[x] = dep[x] = dep[fa] + 1;
for(RG int i = head[x]; i; i = e[i].next)
{
int to = e[i].to; if(to == fa) continue;
dfs(to, x); maxdep[x] = std::max(maxdep[x], maxdep[to]);
if(maxdep[to] > maxdep[heavy[x]]) heavy[x] = to;
}
maxdep[x] = maxdep[heavy[x]] + 1;
// 玄学问题,将第28行注释去掉并注释掉上一行,交到BZOJ就会RE
}
long long *f[maxn], *g[maxn], pool[maxn << 2], *id = pool, ans;
void calc(int x, int fa)
{
if(heavy[x]) f[heavy[x]] = f[x] + 1,
g[heavy[x]] = g[x] - 1, calc(heavy[x], x);
f[x][0] = 1, ans += g[x][0];
for(RG int i = head[x]; i; i = e[i].next)
{
int to = e[i].to; if(to == fa || to == heavy[x]) continue;
f[to] = id; id += maxdep[to] << 1; g[to] = id; id += maxdep[to] << 1;
calc(to, x);
for(RG int j = 0; j < maxdep[to]; j++)
{
if(j) ans += f[x][j - 1] * g[to][j];
ans += g[x][j + 1] * f[to][j];
}
for(RG int j = 0; j < maxdep[to]; j++)
{
g[x][j + 1] += f[x][j + 1] * f[to][j];
if(j) g[x][j - 1] += g[to][j];
f[x][j + 1] += f[to][j];
}
}
}
int main()
{
n = read();
for(RG int i = 1, a, b; i < n; i++)
a = read(), b = read(), add_edge(a, b), add_edge(b, a);
dfs(1, 0); f[1] = id; id += maxdep[1] << 1; g[1] = id; id += maxdep[1] << 1;
calc(1, 0); printf("%lld
", ans);
return 0;
}