传送门:https://codeforces.com/problemset/problem/161/D
题意:
求树上点对距离恰好为k的点对个数
题解:
与poj1741相似
把点分治的模板稍作修改即可:仍然先求出 dep 数组(各点到当前分治中心的距离),再根据它统计距离之和恰好为 k 的点对数。由于 k 最大只有 500,距离超过 k 的点可以直接剪掉,因此只需开一个大小约为 k 的计数数组,记录每种距离的点数即可。
代码:
#include <set>
#include <map>
#include <cmath>
#include <cstdio>
#include <string>
#include <vector>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long LL;
typedef pair<int, int> pii;
typedef unsigned long long uLL;
// Heap-style child-index helpers (children of rt are rt*2 and rt*2+1),
// presumably meant for a segment tree; they expect rt / l / r / mid in scope
// and are not used by the code below. ls/rs are parenthesized so they stay
// correct inside larger expressions such as `ls == x`.
#define ls (rt<<1)
#define rs (rt<<1|1)
#define lson l,mid,rt<<1
#define rson mid+1,r,rt<<1|1
// Debug / local-run helpers (not part of the solution logic).
#define bug printf("*********\n")
#define FIN freopen("input.txt","r",stdin);
#define FON freopen("output.txt","w+",stdout);
#define IO ios::sync_with_stdio(false),cin.tie(0)
// Every argument is parenthesized so operators inside it (e.g. ?:) cannot bind
// to the surrounding << chain.
#define debug1(x) cout<<"["<<#x<<" "<<(x)<<"]\n"
#define debug2(x,y) cout<<"["<<#x<<" "<<(x)<<" "<<#y<<" "<<(y)<<"]\n"
#define debug3(x,y,z) cout<<"["<<#x<<" "<<(x)<<" "<<#y<<" "<<(y)<<" "<<#z<<" "<<(z)<<"]\n"
// Node capacity; the arc pool below is twice as large because each undirected
// edge is stored once per direction.
const int maxn = 3e5 + 5;
const int INF = 0x3f3f3f3f;

// One directed arc of a forward-star adjacency list:
//   v   - target node
//   w   - arc weight
//   nxt - index of the next arc leaving the same source (-1 ends the list)
struct EDGE {
    int v;
    int w;
    int nxt;
};
EDGE edge[maxn << 1];

// head[u]: index of the most recently inserted arc out of u (-1 when none).
// tot:     number of arcs stored so far (also the next free slot).
int head[maxn], tot;

// Prepends the arc u -> v (weight w) to u's list. head[] must have been
// reset to -1 and tot to 0 before the first insertion.
void add_edge(int u, int v, int w) {
    const int slot = tot++;
    edge[slot] = EDGE{v, w, head[u]};
    head[u] = slot;
}
// Per-node scratch arrays for the divide-and-conquer.
int sz[maxn];   // subtree sizes written by the latest get_root() pass
int son[maxn];  // NOTE(review): never read or written in the visible code
int dep[maxn];  // distances collected by dfs(), stored in dep[1..cnt]
int vis[maxn];  // set to 1 once divide() has processed the node

// State shared between get_root() and its callers.
int Maxt;     // best (smallest) score get_root() has seen so far
int root;     // node that achieved Maxt
int Allnode;  // intended size of the component get_root() is searching
int cnt;      // number of entries the last dfs() wrote into dep[]

LL ans;  // answer accumulator (ordered pairs; main prints ans / 2)
int n;   // number of nodes in the current test case
int k;   // required path length
// Searches the component containing u (bounded by nodes with vis set) for a
// centroid: the node whose removal leaves the smallest largest piece.
// Side effects: sz[x] becomes x's subtree size with the component rooted at
// the first call's u; whenever a better node is found, Maxt and root are
// updated. Callers set Allnode to the component size and Maxt to an upper
// bound before the first call.
void get_root(int u, int fa) {
    sz[u] = 1;
    int heaviest = 0;  // size of u's largest child subtree
    for(int i = head[u]; i != -1; i = edge[i].nxt) {
        int v = edge[i].v;
        if(!vis[v] && v != fa) {
            get_root(v, u);
            sz[u] += sz[v];
            heaviest = std::max(heaviest, sz[v]);
        }
    }
    // Removing u leaves each child subtree plus the part "above" u. Scoring by
    // the largest of these is what makes the chosen node a centroid; scoring
    // by sz[u] - 1 (the SUM of the children) could pick a leaf of a star and
    // degrade the decomposition to O(n) depth.
    int tmp = std::max(heaviest, Allnode - sz[u]);
    if(Maxt > tmp) Maxt = tmp, root = u;
}
// Walks from u without stepping back into fa or into any node marked in
// vis[]. dis is u's distance from where the walk started; each visited node's
// distance is appended to dep[] (pre-incrementing cnt). A node at distance
// >= len is still recorded, but the walk does not continue beyond it.
void dfs(int u, int fa, int len, int dis) {
    dep[++cnt] = dis;
    if(dis < len) {
        for(int e = head[u]; e != -1; e = edge[e].nxt) {
            const int to = edge[e].v;
            if(to == fa || vis[to]) continue;
            dfs(to, u, len, dis + 1);
        }
    }
}
// Counts ORDERED pairs (x, y), with x == y allowed, of nodes reachable from rt
// (without entering fa or a node marked in vis[]) whose depths from rt sum to
// exactly len. len <= 0 is answered directly: 1 for the lone pair (rt, rt)
// when len == 0, otherwise 0.
LL cal(int rt, int fa, int len) {
    if(len <= 0) return len == 0;
    cnt = 0;
    dfs(rt, fa, len, 0);  // fills dep[1..cnt]; every recorded depth is <= len
    // num[d] = number of collected nodes at depth d. Sized from len rather than
    // a fixed 505 so that k > 504 cannot index past the end.
    std::vector<int> num(len + 1, 0);
    for(int i = 1; i <= cnt; i++) {
        num[dep[i]]++;
    }
    LL res = 0;
    for(int i = 1; i <= cnt; i++) {
        res += num[len - dep[i]];
    }
    return res;
}
// Divide-and-conquer step. On entry rt is the node chosen by get_root() for
// the current component and Allnode holds that component's size (main and the
// recursive calls below both set it before calling get_root()).
// Adds to ans the ordered pairs whose path passes through rt and has length k.
// Self-pairs (x, x) with x != rt are added by cal(rt, ...) and removed again by
// the matching cal(v, ...), so they cancel.
void divide(int rt) {
    vis[rt] = 1;
    const int total = Allnode;  // size of the component rt was picked from
    // Pairs whose depths from rt sum to k ...
    ans += cal(rt, 0, k);
    for(int i = head[rt]; i != -1; i = edge[i].nxt) {
        int v = edge[i].v;
        if(vis[v]) continue;
        // ... minus pairs lying entirely inside v's side: their depths from rt
        // are (depth from v) + 1 each, hence the k - 2 target.
        ans -= cal(v, rt, k - 2);
        // sz[] was filled with the component rooted at get_root's start node,
        // not at rt. If v hung below rt there, sz[v] is its side's size;
        // otherwise v was rt's parent and its side holds total - sz[rt] nodes.
        // (sz[rt] and sz[v] are not rewritten by earlier recursive calls: those
        // only touch the other, disjoint components.)
        Allnode = sz[v] < sz[rt] ? sz[v] : total - sz[rt];
        Maxt = n;
        get_root(v, rt);
        divide(root);
    }
}
int main() {
#ifndef ONLINE_JUDGE
    FIN
#endif
    // Reads test cases until EOF: "n k", then the n - 1 tree edges.
    while(scanf("%d%d", &n, &k) != EOF) {
        ans = 0;  // ans is global: without this, earlier cases leak into later ones
        memset(head, -1, sizeof(head));
        tot = 0;
        for(int i = 1, u, v; i < n; i++) {
            scanf("%d%d", &u, &v);
            add_edge(u, v, 1);
            add_edge(v, u, 1);
        }
        memset(vis, 0, sizeof(vis));
        Allnode = n;
        Maxt = INF;
        get_root(1, 0);
        divide(root);
        // divide() accumulates ordered pairs, so every unordered pair appears twice.
        printf("%lld\n", ans / 2);
    }
    return 0;
}