Problem
问题 0:已知两棵 (n) 个节点的树的形态(两棵树的节点标号均为 (1) 至 (n)),其中第一棵树是红树,第二棵树是蓝树。要给予每个节点一个([1,y])中的整数,使得对于任意两个节点 (p,q),如果存在一条路径 ((a_1=p,a_2,cdots,a_m=q)) 同时属于这两棵树,则 (p,q) 必须被给予相同的数。求给予数的方案数。
问题 1:已知蓝树,对于红树的所有 (n^{n-2}) 种选择方案,求问题 0 的答案之和。
问题 2:对于蓝树的所有 (n^{n-2}) 种选择方案,求问题 1 的答案之和。
Algorithm
矩阵树定理、容斥、DP、生成函数计数
Solution
问题 0 :有零点难度。
如果 ((i,j)) 边在红树和蓝树都出现,那么 (i,j) 两个结点必须钦定相同的数字,缩到一块。最后就是 (y^{n-|E_1 cap E_2|})。
问题 1 :有一点难度。
首先对于蓝树,令其边集为(E)。考察某一种红树,其边集为(E'),则 (E_0=E cap E') 形成了许多连通块,每个连通块的选择方案均独立,贡献为 (y^{n-|E_0|}) 。
反向考虑有多少种红树,使得 (E_0) 相等。这里的图 (G=<V,E_0>) 形成了 (m=n-|E_0|) 个连通块,块内的边已经连好了,而块与块之间的边未定。假设这些连通块大小为 (a_1,a_2,cdots),(i) 和 (j) 两个连通块之间有 (a_ia_j) 条边。利用矩阵树定理求解问题。最后我们需要求出下面矩阵的行列式
将第 2 行至第 m-1 行元素加到第一行上
第 2 行至第 m-1 行加上第 1 行乘上(dfrac{a_i}{a_n})
最终得到
可是得到的生成树中新连接的某些边又可能在 (E-E_0) 中。考虑容斥。利用容斥的定义得到
其中 (T(E_1)) 表示图 (G=<V,E_1>) 中对连通块间生成树的计数。
而 (E_0) 又能产生贡献 (y^{n-|E_0|})。所以最后需要求解
目标就是求解 (f_m),即蓝树上选m个连通块下,红树的方案数。
考虑组合意义:相当于选取 (m) 个连通块,并在每个连通块内选择一个代表点。由于每个连通块内有 (size) 个代表点,故贡献刚刚好为乘积和。于是我们可以DP了。
如果对于每个结点记录划分了多少连通块,复杂度只能做到 (mathcal O(n^2))(树上背包)。但发现发现式子写成这样
这样每次增加一个连通块时只需把答案乘上 (frac{y}{1-y} imes n) 就行了!(乘上 (n) 由于前面推导的 (n^{m-2}prod a_i))
设 (dp_{u,0/1}) 表示到 (u) 结点时,有/没有选择代表点,令增加连通块的时刻定义与设置代表点等价,转移如下
初始化 (dp_{u,0}=1,dp_{u,1}=frac{ny}{1-y})。
所以答案为
问题 2 :有两点难度。
还是看这个式子
这里 (f_m) 应该重新定义为:两棵树形态任意,边的交形成的连通块大小为m,连通块大小的乘积和。这里采用生成函数计数。
构造生成函数
([x^i]F(x)) 表示选取一个大小为 (i) 的连通块时,内部任意连边,并且附带两次连通块大小乘积的贡献。两次的原因在后面;
表示枚举任意数量的连通块产生的贡献数。
这里可以解释贡献两次的原因了:第一次是对这些连通块形成蓝树的方案数连乘因子的贡献;第二次是形成红树的方案数连乘因子的贡献,两者独立。
所以答案即为
然后化简 (G(x))
其中
多项式 exp 即可。
复杂度(mathcal O(nlog n))。
最后注意(y=1)的情况。
#include <bits/stdc++.h>
const int N = 100005, P = 998244353;
using std::vector;
typedef vector<int> Poly;
typedef long long LL;
int inc(int a, int b) { return (a += b) >= P ? a-P : a; }
int pow(int a, int b) {
int t = 1;
for (; b; b >>= 1, a = 1LL*a*a%P)
if (b & 1) t = 1LL*t*a%P;
return t;
}
int W[N*4], inv[N*4], fac[N*4], ifac[N*4];
void prework(int n) {
for (int i = 1; i < n; i <<= 1)
for (int j = W[i] = 1, Wn = pow(3, (P-1)/i>>1); j < i; j++)
W[i+j] = 1LL*W[i+j-1]*Wn%P;
inv[1] = fac[0] = ifac[0] = 1;
for (int i = 2; i <= n; i++) inv[i] = 1LL*(P-P/i)*inv[P%i]%P;
for (int i = 1; i <= n; i++) fac[i] = 1LL*fac[i-1]*i%P, ifac[i] = 1LL*ifac[i-1]*inv[i]%P;
}
void fft(Poly &a, int n, int opt) {
a.resize(n);
static int rev[N*4];
for (int i = 1; i < n; i++)
if ((rev[i] = rev[i>>1]>>1|(i&1?n>>1:0)) > i) std::swap(a[i], a[rev[i]]);
for (int q = 1; q < n; q <<= 1)
for (int p = 0; p < n; p += q<<1)
for (int i = 0, t; i < q; i++)
t = 1LL*a[p+q+i]*W[q+i]%P, a[p+q+i] = inc(a[p+i], P-t), a[p+i] = inc(a[p+i], t);
if (opt) return;
for (int i = 0, inv = pow(n, P-2); i < n; i++) a[i] = 1LL*a[i]*inv%P;
std::reverse(a.begin()+1, a.end());
}
Poly poly_inv(Poly A) {
Poly B(1, pow(A[0], P-2)), C(2);
for (int L = 1; L < A.size(); L <<= 1) {
for (int i = 0; i < L*2; i++) C[i] = i < A.size() ? A[i] : 0;
fft(B, L*4, 1), fft(C, L*4, 1);
for (int i = 0; i < L*4; i++) B[i] = (2*B[i]-1LL*B[i]*B[i]%P*C[i]%P+P)%P;
fft(B, L*4, 0); B.resize(L*2);
}
return B.resize(A.size()), B;
}
void fix(Poly &A) { int t = A.size(); while (t > 1 && !A[t-1]) t--; A.resize(t); }
int getsz(int n) { int x = 1; while (x < n) x <<= 1; return x; }
Poly operator + (Poly A, Poly B) {
A.resize(std::max(A.size(), B.size()));
for (int i = 0; i < B.size(); i++) A[i] = inc(A[i], B[i]);
return fix(A), A;
}
Poly operator - (Poly A, Poly B) {
A.resize(std::max(A.size(), B.size()));
for (int i = 0; i < B.size(); i++) A[i] = inc(A[i], P-B[i]);
return fix(A), A;
}
Poly operator * (Poly A, Poly B) {
int L = getsz(A.size()+B.size()-1);
fft(A, L, 1), fft(B, L, 1);
for (int i = 0; i < L; i++) A[i] = 1LL*A[i]*B[i]%P;
return fft(A, L, 0), fix(A), A;
}
Poly poly_deri(Poly A) {
for (int i = 0; i < A.size()-1; i++) A[i] = 1LL*(i+1)*A[i+1]%P;
return A.resize(A.size()-1), A;
}
Poly poly_int(Poly A) {
for (int i = A.size()-1; i; i--) A[i] = 1LL*inv[i]*A[i-1]%P;
return A[0] = 0, A;
}
Poly poly_ln(Poly A) {
Poly B = poly_deri(A) * poly_inv(A);
return B.resize(A.size()), poly_int(B);
}
Poly poly_exp(Poly A) {
Poly B(1, 1), C;
for (int L = 1; L < A.size(); L <<= 1)
B.resize(L*2), C = A + Poly(1, 1) - poly_ln(B), C.resize(L*2), B = B*C;
return B.resize(A.size()), B;
}
int n, y, op, f[N][2], k;
vector<int> e[N];
void dp(int u, int fa) {
f[u][0] = 1, f[u][1] = k;
for (int i = 0, v; i < e[u].size(); i++)
if ((v = e[u][i]) != fa) {
dp(v, u);
int f0 = f[u][0], f1 = f[u][1];
f[u][0] = ((LL)f0*f[v][0] + (LL)f0*f[v][1]) % P;
f[u][1] = ((LL)f0*f[v][1] + (LL)f1*f[v][0] + (LL)f1*f[v][1]) % P;
}
}
int main() {
scanf("%d%d%d", &n, &y, &op);
if (y == 1) { printf("%d", op == 0 ? 1 : (op == 1 ? pow(n, n-2) : 1LL*pow(n, n-2)*pow(n, n-2)%P)); return 0; }
if (op == 0) {
for (int i = 1; i < 2*n-1; i++) {
int u, v; scanf("%d%d", &u, &v);
e[u].push_back(v), e[v].push_back(u);
}
int ans = 0;
for (int i = 1; i <= n; i++) {
std::sort(e[i].begin(), e[i].end()); ans += e[i].size();
ans -= std::unique(e[i].begin(), e[i].end()) - e[i].begin();
} ans /= 2;
printf("%d", pow(y, n-ans));
} else {
k = 1LL*n*y%P*pow(P+1-y, P-2)%P;
if (op == 1) {
for (int i = 1; i < n; i++) {
int u, v; scanf("%d%d", &u, &v);
e[u].push_back(v), e[v].push_back(u);
}
dp(1, 0);
printf("%d", 1LL*f[1][1]*pow(P+1-y, n)%P*pow(1LL*n*n%P, P-2)%P);
} else {
prework((n+1)*2);
Poly F(n+1);
for (int i = 1; i <= n; i++)
F[i] = 1LL*n*k%P*pow(i, i)%P*ifac[i]%P;
F = poly_exp(F);
printf("%d", 1LL*F[n]*fac[n]%P*pow(P+1-y, n)%P*pow(pow(n, 4), P-2)%P);
}
}
return 0;
}