[POI2014]Hotel加强版
参考博客
题目大意
给出 (n) 个点的树,求有多少个元素互不相同的无序三元组 ((a, b, c)) 满足两两之间树上距离相等
数据范围
(nle100000)
时空限制
10sec,128MB
分析
首先设 (f(i,j)) 为 (i) 的子树中距离 (i) 为 (j) 的子树个数,设 (g(i,j)) 为 (i) 的子树中存在两点的 (lca) 与它们距离皆为 (d) ,且(lca) 距离 (i) 为 (d-j) 的方案数,那么 (f(i,j) imes g(i,j)) 就是答案,转移复杂度 (O(n^2)) 因为答案只与深度有关,用长链剖分优化至 (O(n))
Code
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
inline char nc() {
static char buf[100000], *l = buf, *r = buf;
return l==r&&(r=(l=buf)+fread(buf,1,100000,stdin),l==r)?EOF:*l++;
}
template<class T> void read(T & x) {
x = 0; int f = 1, ch = nc();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=nc();}
while(ch>='0'&&ch<='9'){x=x*10-'0'+ch;ch=nc();}
x *= f;
}
typedef long long ll;
const int maxn = 100000 + 5;
const int maxe = maxn * 2;
const int maxnode = maxn * 6;
int n; ll an;
int head[maxn], ecnt;
int len[maxn], son[maxn];
ll temp[maxnode], * f[maxn], * g[maxn], * now = temp;
struct edge {
int to, nex;
edge(int to=0, int nex=0) : to(to), nex(nex) {}
} G[maxe];
inline void addedge(int u, int v ) {
G[ecnt] = edge(v, head[u]), head[u] = ecnt++;
G[ecnt] = edge(u, head[v]), head[v] = ecnt++;
}
void dfs(int u, int fa) {
son[u] = -1;
for(int i = head[u]; ~ i; i = G[i].nex) {
int v = G[i].to; if(v == fa) continue;
dfs(v, u);
len[u] = max(len[u], len[v] + 1);
if(son[u] == -1 || len[v] > len[son[u]]) {
son[u] = v;
}
}
}
int tim;
void dp(int u, int fa) {
if(~ son[u]) {
int v = son[u];
f[v] = f[u] + 1;
g[v] = g[u] - 1;
dp(v, u);
}
f[u][0] = 1; an += f[u][0] * g[u][0];
for(int i = head[u]; ~ i; i = G[i].nex) {
int v = G[i].to; if(v == fa || v == son[u]) continue;
f[v] = now, now += (len[v] + 1) << 1;
g[v] = now, now += (len[v] + 1) << 1;
dp(v, u);
for(int j = 1; j <= len[v]; ++j) an += g[v][j] * f[u][j - 1];
for(int j = 0; j <= len[v]; ++j) an += f[v][j] * g[u][j + 1];
for(int j = 1; j <= len[v]; ++j) g[u][j - 1] += g[v][j];
for(int j = 0; j <= len[v]; ++j) g[u][j + 1] += f[v][j] * f[u][j + 1];
for(int j = 0; j <= len[v]; ++j) f[u][j + 1] += f[v][j];
}
}
void solve() {
dfs(1, 0);
f[1] = now, now += (len[1] + 1) << 1;
g[1] = now, now += (len[1] + 1) << 1;
dp(1, 0);
cout << an << endl;
}
int main() {
// freopen("2.txt", "r", stdin);
read(n);
memset(head, -1, sizeof(head));
for(int i = 1; i < n; ++i) {
int u, v; read(u), read(v);
addedge(u, v);
}
solve();
return 0;
}