@description@
本题包含三个问题:
问题 0:已知两棵 n 个结点的树的形态(两棵树的结点标号均为 1~n),其中第一棵树是红树,第二棵树是蓝树。要给予每个结点一个 [1, y] 中的整数,使得对于任意两个节点 p, q,如果存在一条路径 (a1 = p, a2, ..., am = q) 同时属于这两棵树,则 p, q 必须被给予相同的数。求给予数的方案数。
问题 1:已知蓝树,对于红树的所有 (n^{n-2}) 种选择方案,求问题 0 的答案之和。
问题 2:对于蓝树的所有 (n^{n-2}) 种选择方案,求问题 1 的答案之和。
@solution@
说点人话,若两棵树边集的交集为 S,则答案等于 (y^{n - |S|})。
前排提醒:下面可能会出现类似 (1 - y) 作分母的情况,当 y = 1 时没有意义。所以需要优先特判掉。
注意 y = 1 时 |S| 并不会影响,所以只取决于有多少种可能的情况。
@问题 0@
相信大家都会做。
@问题 1@
不难想到一个指数级的思路:枚举交集 S,记 f(S) 表示满足要求的树的个数。
交集恰好为 S 显然不好做,而且看起来很好容斥。我们枚举 T,计算交集包含 T 的情况,记为 g(T)。
稍微思考一下得到容斥式子 (f(S) = sum_{Ssubseteq T}(-1)^{|T|-|S|}g(T))。
则最终答案有如下式子:
尝试消去 S:
用一个二项式定理就可以得到 (ans = y^nsum_{T}g(T)(y^{-1} - 1)^{|T|})。
不妨先记 (u = (y^{-1} - 1)),则 (ans = y^nsum_{T}g(T)u^{|T|})。
尽管如此还是一个指数级算法。考虑 g(T) 应该怎么求,然后优化成多项式算法。
如果给定边集 T,只要另一棵树中包含 T 中这些边即可。因此相当于先用 T 中的边将 1~n 的点连成 k 个大小为 a1, a2, ..., ak 的连通块,然后再连成一棵树的方案数。
用 matrix-tree / prufer 可以证明这个方案数为 (g(T) = n^{k-2} imesprod_{i=1}^{k}a_i)(证明详见下面的补充部分)。
由于 T 中的边连成的连通块个数 (k = n-|T|),所以将原式进一步改写为:
可以作 O(n^2) 的树形 dp:记 dp[i][j] 表示以 i 为根的子树被分成了若干连通块,其中 i 所在的连通块大小为 j,其他连通块的总贡献为 dp[i][j]。
当然可以更简单:考虑 (a_i imes n imes u^{-1}) 的组合意义。即大小为 (a_i) 的连通块中选择一个,贡献 (n imes u^{-1})。
然后记 dp[0/1][i] 表示 i 所在的连通块是否有点贡献了 (n imes u^{-1}),这样子就是 O(n) 的树形 dp 了。
@问题 2@
如果你像我一开始一样,从上面的 dp[0/1][i] 入手,最后就会陷入两个生成函数互相卷积的怪圈中,只能分治 fft O(nlog^2n) 求解。。。
考虑依然是容斥,其它过程都与上面一样,只是 g(T) 的计算式子变为 (g(T) = (n^{k-2} imesprod_{i=1}^{k}a_i)^2)(因为要枚举两棵树嘛)。
那么最终答案 (h[n] = frac{y^n imes u^n}{n^4} imessum_{T}(prod_{i=1}^{k}(a_i^2 imes n^2 imes u^{-1})))。
现在枚举 T 反而不好办了。我们考虑直接枚举序列 a,算出有多少边集 T。不妨令点 1 所在的连通块大小为 a1,枚举与点 1 在同一连通块的点得到 h 的转移:
上面那个可以直接 O(n^2) 做了。不过还可以进一步优化:
记 (p[i] = (i+1)^2 imes n^2 imes u^{-1} imes (i+1)^{i-1}),则上面的卷积又可以写作 (h[n+1] = sum_{i=0}^{n}C_{n}^{i} imes p[i] imes h[n-i])。
这是一个经典的卷积式子,可以写成指数型生成函数然后求多项式 exp(具体可见下面的补充部分)。
时间复杂度 O(nlogn)。
@补充部分@
对上面所提到的两个问题的细节补充。
(1)1~n 的点连成 k 个大小为 a1, a2, ..., ak 的连通块,然后再连成一棵树的方案数为 (n^{k-2} imesprod_{i=1}^{k}a_i)。
证明我选择的是 prufer 序列(懒得写matrix-tree的矩阵证法,网上应该找得到)。
由于一个数在 prufer 序列中的出现次数为它的度数减一,又因为从某个大小为 ai 的连通块连出去一条边有 ai 种选择,所以有:
关于后面那个怎么来的,其实是逆用多项式的展开:
(2)关于指数型生成函数的 exp 对应的卷积意义。
首先要认识到,对于指数型生成函数而言,积分相等于右移,求导相当于左移。
假如令 (f(x) = sum_{i=0}frac{a_{i}}{i!}x^i),则 (f'(x) = sum_{i=0}frac{a_{i+1}}{i!}x^i),(int f(x) = sum_{i=1}frac{a_{i-1}}{i!}x^i)。
根据求导法则,有 (ln(f(x))' = frac{f'(x)}{f(x)}),即 (ln(f(x))' imes f(x) = f'(x))。
如果记 (g(x) = ln(f(x))' = sum_{i=0}frac{b_{i}}{i!}x^i),比较第 n 项等式两边的系数,可以得到:
然后可以推出 (a_{n+1} = sum_{i=0}^{n}C_n^i imes a_i imes b_{n-i}),就是我们题目中的卷积式子。
@accepted code@
#include <set>
#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;
const int MOD = 998244353;
const int MAXN = 400000;
struct mint{
int x;
mint(int _x = 0) : x(_x) {}
friend mint operator + (mint a, const mint &b) {return (a.x + b.x) % MOD;}
friend mint operator - (mint a, const mint &b) {return (a.x + MOD - b.x) % MOD;}
friend mint operator * (mint a, const mint &b) {return 1LL * a.x * b.x % MOD;}
friend void operator += (mint &a, const mint &b) {a = a + b;}
friend void operator -= (mint &a, const mint &b) {a = a - b;}
friend void operator *= (mint &a, const mint &b) {a = a * b;}
friend mint mpow(mint b, int p) {
if( b.x == 1 ) return 1;
mint ret = 1;
while( p ) {
if( p & 1 ) ret = ret * b;
b = b * b;
p >>= 1;
}
return ret;
}
friend mint operator / (mint a, const mint &b) {return a * mpow(b, MOD - 2);}
friend void operator /= (mint &a, const mint &b) {a = a / b;}
};
int n, y, op;
void solve0() {
if( op == 0 ) printf("%d
", 1);
else if( op == 1 ) printf("%d
", mpow((mint)n, n - 2).x);
else printf("%d
", mpow((mint)n, 2*(n - 2)).x);
}
set<pair<int, int> >e;
void solve1() {
int ans = 0;
for(int i=1;i<n;i++) {
int u, v; scanf("%d%d", &u, &v);
if( u > v ) swap(u, v);
e.insert(make_pair(u, v));
}
for(int i=1;i<n;i++) {
int u, v; scanf("%d%d", &u, &v);
if( u > v ) swap(u, v);
if( e.count(make_pair(u, v)) ) ans++;
}
printf("%d
", mpow((mint)y, n - ans).x);
}
struct edge{
edge *nxt; int to;
}edges[2*MAXN + 5], *adj[MAXN + 5], *ecnt = edges;
void addedge(int u, int v) {
edge *p = (++ecnt);
p->to = v, p->nxt = adj[u], adj[u] = p;
p = (++ecnt);
p->to = u, p->nxt = adj[v], adj[v] = p;
}
mint dp[2][MAXN + 5], del;
void dfs(int x, int f) {
dp[0][x] = 1, dp[1][x] = del;
for(edge *p=adj[x];p;p=p->nxt) {
if( p->to == f ) continue;
dfs(p->to, x);
dp[1][x] = dp[1][x] * dp[1][p->to] + dp[1][x] * dp[0][p->to] + dp[0][x] * dp[1][p->to];
dp[0][x] = dp[0][x] * dp[1][p->to] + dp[0][x] * dp[0][p->to];
}
}
void solve2() {
for(int i=1;i<n;i++) {
int u, v; scanf("%d%d", &u, &v);
addedge(u, v);
}
mint u = 1; u = (u - y) / y;
mint p = mpow(y * u, n) / n / n;
del = n / u, dfs(1, 0);
printf("%d
", (dp[1][1] * p).x);
}
namespace poly{
const mint G = 3;
mint w[20], iw[20], inv[MAXN + 5];
void init() {
inv[1] = 1;
for(int i=2;i<=MAXN;i++)
inv[i] = MOD - inv[MOD%i]*(MOD/i);
for(int i=0;i<20;i++)
w[i] = mpow(G, (MOD-1)/(1<<i)), iw[i] = 1 / w[i];
}
void ntt(mint *A, int n, int type) {
for(int i=0,j=0;i<n;i++) {
if( i < j ) swap(A[i], A[j]);
for(int k=(n>>1);(j^=k)<k;k>>=1);
}
for(int i=1;(1<<i)<=n;i++) {
int s = (1 << i), t = (s >> 1);
mint u = (type == 1 ? w[i] : iw[i]);
for(int j=0;j<n;j+=s) {
mint p = 1;
for(int k=0;k<t;k++,p*=u) {
mint x = A[j + k], y = A[j + k + t];
A[j + k] = x + y*p, A[j + k + t] = x - y*p;
}
}
}
if( type == -1 ) {
for(int i=0;i<n;i++)
A[i] *= inv[n];
}
}
mint t1[MAXN + 5], t2[MAXN + 5];
int length(int n) {
int l; for(l = 1; l < n; l <<= 1);
return l;
}
void mul(mint *A, int nA, mint *B, int nB, mint *C) {
int nC = (nA + nB - 1), len = length(nC);
for(int i=0;i<nA;i++) t1[i] = A[i];
for(int i=nA;i<len;i++) t1[i] = 0;
for(int i=0;i<nB;i++) t2[i] = B[i];
for(int i=nB;i<len;i++) t2[i] = 0;
ntt(t1, len, 1), ntt(t2, len, 1);
for(int i=0;i<len;i++) C[i] = t1[i] * t2[i];
ntt(C, len, -1);
}
mint t3[MAXN + 5], t4[MAXN + 5];
void pinv(mint *A, mint *B, int n) {
if( n == 1 ) {
B[0] = 1 / A[0];
return ;
}
int m = (n + 1) >> 1;
pinv(A, B, m);
int len = length(n << 1);
for(int i=0;i<m;i++) t3[i] = B[i];
for(int i=m;i<len;i++) t3[i] = 0;
for(int i=0;i<n;i++) t4[i] = A[i];
for(int i=n;i<len;i++) t4[i] = 0;
ntt(t3, len, 1), ntt(t4, len, 1);
for(int i=0;i<len;i++)
B[i] = t3[i] * (2 - t3[i] * t4[i]);
ntt(B, len, -1);
}
void pdif(mint *A, mint *B, int n) {
for(int i=1;i<n;i++)
B[i-1] = A[i] * i;
}
void pint(mint *A, mint *B, int n) {
for(int i=n-1;i>=0;i--)
B[i+1] = A[i] * inv[i + 1];
B[0] = 0;
}
mint t5[MAXN + 5], t6[MAXN + 5];
void ln(mint *A, mint *B, int n) {
pdif(A, t5, n), pinv(A, t6, n);
mul(t5, n - 1, t6, n, t5);
pint(t5, B, n);
}
mint t7[MAXN + 5], t8[MAXN + 5];
void exp(mint *A, mint *B, int n) {
if( n == 1 ) {
B[0] = 1;
return ;
}
int m = (n + 1) >> 1;
exp(A, B, m);
for(int i=0;i<m;i++) t7[i] = B[i];
for(int i=m;i<n;i++) t7[i] = 0;
ln(t7, t8, n);
for(int i=0;i<n;i++) t7[i] = A[i] - t8[i];
t7[0].x += 1;
for(int i=0;i<m;i++) t8[i] = B[i];
mul(t7, n, t8, m, B);
}
}
mint fct[MAXN + 5], ifct[MAXN + 5];
void init() {
poly::init(); fct[0] = 1;
for(int i=1;i<=MAXN;i++) fct[i] = fct[i-1] * i;
ifct[MAXN] = 1 / fct[MAXN];
for(int i=MAXN-1;i>=0;i--) ifct[i] = ifct[i+1] * (i+1);
}
/*
mint comb(int n, int m) {
return fct[n] * ifct[m] * ifct[n-m];
}
*/
mint f[MAXN + 5], g[MAXN + 5];
void solve3() {
init();
mint u = 1; u = (u - y) / y;
mint p = mpow(y * u, n) / (mint(n) * n * n * n);
del = n / u * n;
/*
for(int i=0;i<n;i++)
g[i] = mpow(mint(i+1), i-1) * del * mint(i+1) * mint(i+1);
f[0] = 1;
for(int i=0;i<n;i++)
for(int j=0;j<=i;j++)
f[i+1] += comb(i, j)*g[i-j]*f[j];
printf("%d
", (f[n] * p).x);
*/
for(int i=0;i<n;i++)
g[i] = mpow(mint(i+1), i-1) * del * mint(i+1) * mint(i+1), g[i] *= ifct[i];
poly::pint(g, g, n);
poly::exp(g, f, n + 1);
printf("%d
", (f[n] * p * fct[n]).x);
}
int main() {
scanf("%d%d%d", &n, &y, &op);
if( y == 1 ) solve0();
else if( op == 0 ) solve1();
else if( op == 1 ) solve2();
else if( op == 2 ) solve3();
}
@details@
讲道理,这道题并不算太难分析。
不过可以学到很多分析组合计数的知识与技巧。