题意简述:
给定一棵 (n) 个节点的树,设它的直径是 (D),问有多少个集合满足集合中每两个点的距离都为 (D)。
( exttt{Data Range:} 1le nle 2 imes 10^5)。
考虑直径的性质:
- 树的每一条直径一定都经过一个公共点 / 一条公共边。经过的是点还是边取决于直径的长度是奇数还是偶数。
那么按照直径长度的奇偶性分类讨论,直接计算即可。
具体的:
- 若直径长度为偶数,设中点为 (mid),答案为 (prodlimits_{vin son_{mid}}(cnt_v+1)-scnt-1),其中 (cnt_v) 为 (v) 子树中距离 (v) 长度为 (frac{D}{2}-1) 的个数,(scnt) 为距离 (mid) 长度为 (frac{D}{2}) 的点的个数。
- 若直径长度为奇数,设中间的边为 ((mid,fmid)),那么答案就是 (cnt_{mid} imes cnt_{fmid}),(cnt_{mid}) 为(mid) 子树中距离 (mid) 长度为 (frac{D}{2}) 的点的个数。
代码:
#include <bits/stdc++.h>
#define DC int T = gi <int> (); while (T--)
#define DEBUG fprintf(stderr, "Passing [%s] line %d
", __FUNCTION__, __LINE__)
#define File(x) freopen(x".in","r",stdin); freopen(x".out","w",stdout)
#define fi first
#define se second
#define pb push_back
#define mp make_pair
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
typedef pair <int, int> PII;
typedef pair <LL, LL> PLL;
template <typename T>
inline T gi()
{
T x = 0, f = 1; char c = getchar();
while (c < '0' || c > '9') {if (c == '-') f = -1; c = getchar();}
while (c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return f * x;
}
const int N = 200003, M = N << 1, mod = 998244353;
int n;
int tot, head[N], ver[M], nxt[M];
int fa[N];
int mx, lft, rght;
int cnt;
inline void add(int u, int v) {ver[++tot] = v, nxt[tot] = head[u], head[u] = tot;}
inline int qpow(int x, int y)
{
int res = 1;
while (y)
{
if (y & 1) res = 1ll * res * x % mod;
x = 1ll * x * x % mod, y >>= 1;
}
return res;
}
void dfs(int u, int f, int dis)
{
if (dis > mx) mx = dis, rght = u;
fa[u] = f;
for (int i = head[u]; i; i = nxt[i])
{
int v = ver[i];
if (v == f) continue;
dfs(v, u, dis + 1);
}
}
void dfsson(int u, int f, int tar, int dis)
{
if (tar == dis) ++cnt;
for (int i = head[u]; i; i = nxt[i])
{
int v = ver[i];
if (v == f) continue;
dfsson(v, u, tar, dis + 1);
}
}
int main()
{
//freopen(".in", "r", stdin); freopen(".out", "w", stdout);
n = gi <int> ();
for (int i = 1; i < n; i+=1)
{
int u = gi <int> (), v = gi <int> ();
add(u, v), add(v, u);
}
dfs(1, 0, 0);
lft = rght, mx = 0;
dfs(lft, 0, 0);
int d = mx;
if (d % 2 == 0)
{
int mid = rght;
for (int i = 1; i <= d / 2; i+=1) mid = fa[mid];
int ans = 1, scnt = 0;
for (int i = head[mid]; i; i = nxt[i])
{
int v = ver[i];
cnt = 0;
dfsson(v, mid, d / 2 - 1, 0);
ans = 1ll * ans * (cnt + 1) % mod;
scnt += cnt;
}
printf("%d
", (ans - 1 - scnt + mod) % mod);
}
else
{
int mid = rght;
for (int i = 1; i <= d / 2; i+=1) mid = fa[mid];
int fmid = fa[mid];
dfsson(mid, fmid, d / 2, 0);
int tcnt = cnt; cnt = 0;
dfsson(fmid, mid, d / 2, 0);
printf("%lld
", 1ll * cnt * tcnt % mod);
}
return !!0;
}