三点在树上距离相等的情况只有一种,就是以某一个点为中心,三个点到这个点的距离相等。
所以直接枚举每个点作为中心,dfs这个中心的子树,根据乘法原理统计答案即可。
时间复杂度 O(n2) (n <= 5000)
代码
#include <cstdio> #include <cstring> #include <iostream> #define N 5001 #define LL long long LL ans; int n, cnt; int head[N], to[N << 1], next[N << 1], one[N], two[N], dis[N], tmp[N]; inline int read() { int x = 0, f = 1; char ch = getchar(); for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = -1; for(; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + ch - '0'; return x * f; } inline void add(int x, int y) { to[cnt] = y; next[cnt] = head[x]; head[x] = cnt++; } inline void dfs(int u) { int i, v; tmp[dis[u]]++; for(i = head[u]; i ^ -1; i = next[i]) { v = to[i]; if(!dis[v]) { dis[v] = dis[u] + 1; dfs(v); } } } int main() { int i, j, k, x, y; n = read(); memset(head, -1, sizeof(head)); for(i = 1; i < n; i++) { x = read(); y = read(); add(x, y); add(y, x); } for(i = 1; i <= n; i++) { memset(dis, 0, sizeof(dis)); memset(one, 0, sizeof(one)); memset(two, 0, sizeof(two)); dis[i] = 1; for(j = head[i]; j ^ -1; j = next[j]) { memset(tmp, 0, sizeof(tmp)); dis[to[j]] = 2; dfs(to[j]); for(k = 1; k <= n; k++) { ans += (LL)two[k] * tmp[k]; two[k] += one[k] * tmp[k]; one[k] += tmp[k]; } } } printf("%lld ", ans); return 0; }