Description
定义一次点分治的复杂度是所有分治中心分治时的子树大小之和。
给定一棵树,问所有点等概率被选做重心,点分治的期望复杂度。
Solution
根据期望的线性性,答案等价于每个点在点分树上的深度期望之和。
思路是从点对的角度考虑某一个点是否会产生贡献。
[E(depth[x])=sum_{y=1}^n P(xin subtree[y])
]
也就是 (x) 在点分树上在 (1dots n) 的子树中的概率和。
考虑点分树上 (y) 是 (x) 的祖先的条件,要求 (x) 和 (y) 构成的这条链上第一个在点分治过程中被删除的点是 (y) ,由于链上被选中的概率相等,因此这个概率为 (frac{1}{dist(x,y) + 1})。
所以答案为
[sum_{x=1}^nsum_{j=1}^n frac{1}{dis(i,j) + 1}=sum_{len = 0}^n frac{cnt[i]}{i + 1}
]
因此需要点分治求长度为 (i) 的路径条数 (cnt[i]) ,注意到合并的时候是卷积的形式。
容斥做法
不考虑重复路径,把子树 dfs 一遍,直接自己进行卷积,再去掉子树内重复计数的路径即可。
每一层最差以自己的 (size) 作为长度进行卷积,因此复杂度为 (mathcal O(nlog^2 n))
#include <cmath>
#include <cstdio>
#include <cctype>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#define N 65537
#define mod 998244353
using namespace std;
typedef long long ll;
inline int rd() {
int x = 0;
char c = getchar();
while (!isdigit(c)) c = getchar();
while (isdigit(c)) {
x = x * 10 + (c ^ 48); c = getchar();
}
return x;
}
inline void print(ll x) {
int y = 10, len = 1;
while(y <= x) {y *= 10; ++len;}
while(len--) {y /= 10; putchar(x / y + 48); x %= y;}
putchar('
');
}
inline int fpow(int x, int t = mod - 2) {
int res = 1;
while (t) {
if (t & 1) res = 1ll * res * x % mod;
x = 1ll * x * x % mod; t >>= 1;
}
return res;
}
int mxlen = (1 << 16), w[2][N], rev[N];
inline int mo(int x) {
return x >= mod ? x - mod : x;
}
inline void init() {
int per = fpow(3, (mod - 1) / mxlen);
int invper = fpow(per);
w[0][0] = w[1][0] = 1;
for (int i = 1; i < mxlen; ++i) {
w[0][i] = 1ll * w[0][i - 1] * per % mod;
w[1][i] = 1ll * w[1][i - 1] * invper % mod;
}
}
inline int Rev(int n) {
int len = 1, bit = 0;
while (len <= n) len <<= 1, ++bit;
for (int i = 0; i < len; ++i)
rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (bit - 1)));
return len;
}
inline void NTT(int *f, int len, int o) {
for (int i = 0; i < len; ++i)
if (i > rev[i]) swap(f[i], f[rev[i]]);
for (int i = 1; i < len; i <<= 1) {
int wn = mxlen / (i << 1);
for (int j = 0; j < len; j += (i << 1)) {
int nw = 0, x, y;
for (int k = 0; k < i; ++k, nw += wn) {
x = f[j + k];
y = 1ll * w[o][nw] * f[i + j + k] % mod;
f[j + k] = mo(x + y);
f[i + j + k] = mo(x - y + mod);
}
}
}
if (o == 1) {
int invl = fpow(len);
for (int i = 0; i < len; ++i) f[i] = 1ll * f[i] * invl % mod;
}
}
bool vis[N];
int n, m, tot, totn, mx, rt, mxd;
int bkt[N], cnt[N], sz[N], hd[N];
struct edge{int to, nxt;} e[N << 1];
inline void add(int u, int v) {
e[++tot].to = v; e[tot].nxt = hd[u]; hd[u] = tot;
e[++tot].to = u; e[tot].nxt = hd[v]; hd[v] = tot;
}
void getrt(int u, int fa) {
sz[u] = 1;
int mxs = 0;
for (int i = hd[u], v; i; i = e[i].nxt)
if ((v = e[i].to) != fa && !vis[v]) {
getrt(v, u);
sz[u] += sz[v];
mxs = max(mxs, sz[v]);
}
mxs = max(mxs, totn - sz[u]);
if (mxs < mx) {mx = mxs; rt = u;}
}
void getsz(int u, int fa) {
sz[u] = 1;
for (int i = hd[u], v; i; i = e[i].nxt)
if ((v = e[i].to) != fa && !vis[v]) {
getsz(v, u); sz[u] += sz[v];
}
}
void dfs(int u, int fa, int dep) {
++bkt[dep]; mxd = max(mxd, dep);
for (int i = hd[u], v; i; i = e[i].nxt)
if ((v = e[i].to) != fa && !vis[v]) dfs(v, u, dep + 1);
}
inline void mul(int *a, int len, int o) {
len = Rev(len << 1);
NTT(a, len, 0);
for (int i = 0; i < len; ++i) a[i] = 1ll * a[i] * a[i] % mod;
NTT(a, len, 1);
if (o > 0) for (int i = 0; i < len; ++i) cnt[i + 1] += a[i];
else for (int i = 0; i < len; ++i) cnt[i + 3] -= a[i];
for (int i = 0; i < len; ++i) a[i] = 0;
}
inline void calc(int u, int o) {
mxd = 0;
dfs(u, 0, 0);
mul(bkt, mxd, o);
}
void divide(int u) {
vis[u] = 1;
calc(u, 1);
for (int i = hd[u], v; i; i = e[i].nxt)
if (!vis[v = e[i].to]) {
calc(v, -1);
getsz(v, u);
totn = mx = sz[v]; rt = v;
getrt(v, 0); divide(rt);
}
}
int main() {
init();
n = rd();
for (int i = 1; i < n; ++i) add(rd() + 1, rd() + 1);
mx = totn = n;
getrt(1, 0); divide(rt);
double ans = 0.0;
for (int i = 1; i <= n + 1; ++i) ans += (double) cnt[i] / i;
printf("%.4lf", ans);
return 0;
}
子树按秩合并做法
在点分治求路径条数时,我们尝试用按秩合并的思路去搞,也就是将子树按照最深深度排序,然后逐个合并计算答案。
开始的时候只有 (bkt[0]=1),然后按顺序卷每一个子树求出来的计数数组 (bktson) 。
把贡献直接计算,然后再将 (bktson) 按位加到 (bkt) 上。
考虑复杂度,将子树按照深度从小到大排序后,每次卷积得到的新的链长不会超过新合并的子树深度的二倍,所以每次卷积的数组长度为 (mxdep[v]) 的,且每个位置只会和其父节点卷积一次,因此总复杂度为 (mathcal O(nlog^2 n))
#include <cmath>
#include <cstdio>
#include <cctype>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#define N 65537
#define mod 998244353
using namespace std;
typedef long long ll;
inline int rd() {
int x = 0;
char c = getchar();
while (!isdigit(c)) c = getchar();
while (isdigit(c)) {
x = x * 10 + (c ^ 48); c = getchar();
}
return x;
}
inline void print(ll x) {
int y = 10, len = 1;
while(y <= x) {y *= 10; ++len;}
while(len--) {y /= 10; putchar(x / y + 48); x %= y;}
putchar('
');
}
inline int fpow(int x, int t = mod - 2) {
int res = 1;
while (t) {
if (t & 1) res = 1ll * res * x % mod;
x = 1ll * x * x % mod; t >>= 1;
}
return res;
}
int mxlen = (1 << 16), w[2][N], rev[N];
inline int mo(int x) {
return x >= mod ? x - mod : x;
}
inline void init() {
int per = fpow(3, (mod - 1) / mxlen);
int invper = fpow(per);
w[0][0] = w[1][0] = 1;
for (int i = 1; i < mxlen; ++i) {
w[0][i] = 1ll * w[0][i - 1] * per % mod;
w[1][i] = 1ll * w[1][i - 1] * invper % mod;
}
}
inline int Rev(int n) {
int len = 1, bit = 0;
while (len <= n) len <<= 1, ++bit;
for (int i = 0; i < len; ++i)
rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (bit - 1)));
return len;
}
inline void NTT(int *f, int len, int o) {
for (int i = 0; i < len; ++i)
if (i > rev[i]) swap(f[i], f[rev[i]]);
for (int i = 1; i < len; i <<= 1) {
int wn = mxlen / (i << 1);
for (int j = 0; j < len; j += (i << 1)) {
int nw = 0, x, y;
for (int k = 0; k < i; ++k, nw += wn) {
x = f[j + k];
y = 1ll * w[o][nw] * f[i + j + k] % mod;
f[j + k] = mo(x + y);
f[i + j + k] = mo(x - y + mod);
}
}
}
if (o == 1) {
int invl = fpow(len);
for (int i = 0; i < len; ++i) f[i] = 1ll * f[i] * invl % mod;
}
}
bool vis[N];
double ans = 0.0;
int n, m, tot, totn, mx, rt;
int bkt[N], sz[N], hd[N];
struct edge{int to, nxt;} e[N << 1];
inline void add(int u, int v) {
e[++tot].to = v; e[tot].nxt = hd[u]; hd[u] = tot;
e[++tot].to = u; e[tot].nxt = hd[v]; hd[v] = tot;
}
void getrt(int u, int fa) {
sz[u] = 1;
int mxs = 0;
for (int i = hd[u], v; i; i = e[i].nxt)
if ((v = e[i].to) != fa && !vis[v]) {
getrt(v, u);
sz[u] += sz[v];
mxs = max(mxs, sz[v]);
}
mxs = max(mxs, totn - sz[u]);
if (mxs < mx) {mx = mxs; rt = u;}
}
void getsz(int u, int fa) {
sz[u] = 1;
for (int i = hd[u], v; i; i = e[i].nxt)
if ((v = e[i].to) != fa && !vis[v]) {
getsz(v, u); sz[u] += sz[v];
}
}
int res[N], tmp[N];
inline int mul(int *a, int *b, int lena, int lenb) {
int len = Rev(lenb << 1);
for (int i = 0; i < lena; ++i) res[i] = a[i];
for (int i = lena; i < len; ++i) res[i] = 0;
for (int i = 0; i < lenb; ++i) tmp[i] = b[i];
for (int i = lenb; i < len; ++i) tmp[i] = 0;
NTT(res, len, 0); NTT(tmp, len, 0);
for (int i = 0; i < len; ++i) res[i] = 1ll * res[i] * tmp[i] % mod;
NTT(res, len, 1);
for (int i = 0; i < len; ++i) ans += 2.0 * res[i] / (i + 1);
return len;
}
int mxd[N], s[N], bkts[N];
inline bool cmp(int x, int y) {return mxd[x] < mxd[y];}
int dfs(int u, int fa, int dep) {
int resd = dep;
for (int i = hd[u], v; i; i = e[i].nxt)
if ((v = e[i].to) != fa && !vis[v]) resd = max(resd, dfs(v, u, dep + 1));
return resd;
}
void dfs2(int u, int fa, int dep) {
++bkts[dep];
for (int i = hd[u], v; i; i = e[i].nxt)
if ((v = e[i].to) != fa && !vis[v]) dfs2(v, u, dep + 1);
}
void divide(int u) {
vis[u] = 1;
s[0] = 0;
for (int i = hd[u], v; i; i = e[i].nxt)
if (!vis[v = e[i].to]) {
s[++s[0]] = v;
mxd[v] = dfs(v, u, 1);
}
sort(s + 1, s + 1 + s[0], cmp);
bkt[0] = 1;
int nowlen = 1;
for (int i = 1, v; i <= s[0]; ++i) {
dfs2(v = s[i], 0, 1);
nowlen = mul(bkt, bkts, nowlen, mxd[v] + 1);
for (int i = 0; i <= mxd[v]; ++i) {
bkt[i] += bkts[i]; bkts[i] = 0;
}
}
for (int i = 0; i <= nowlen; ++i) bkt[i] = 0;
for (int i = hd[u], v; i; i = e[i].nxt)
if (!vis[v = e[i].to]) {
getsz(v, u);
totn = mx = sz[v]; rt = v;
getrt(v, 0); divide(rt);
}
}
int main() {
init();
n = rd();
for (int i = 1; i < n; ++i) add(rd() + 1, rd() + 1);
mx = totn = n;
getrt(1, 0); divide(rt);
printf("%.4lf", ans + n);
return 0;
}