@description@
在一个 s 个点的图中,存在 s - n 条边,使图中形成了 n 个连通块,第 i 个连通块中有 (a_i) 个点。
现在我们需要再连接 n - 1 条边,使该图变成一棵树。对一种连边方案,设原图中第 i 个连通块连出了 (d_i) 条边,那么这棵树 T 的价值为:
你的任务是求出所有可能的生成树的价值之和,对 998244353 取模。
原题戳我。
@solution@
@正文@
注意到 (d_i) 为度数,那么考虑 prufer 序列,直接写出答案表达式:
其中 (b_i + 1 = d_i)。
作一些简单的变形:
引入生成函数。如果记 (P(x) = sum_{i=0}frac{(i + 1)^{2m} imes x^i}{i!}),(Q(x) = sum_{i=0}frac{(i + 1)^{m} imes x^i}{i!}),则:
注意到 (frac{P(a_i x)}{Q(a_i x)}) 其实就是 (frac{P(x)}{Q(x)}) 的第 k 项乘上 (a_i^{k})。
也就是说 (sum_{i=1}^{n}frac{P(a_i x)}{Q(a_i x)}) 就是 (frac{P(x)}{Q(x)}) 的第 k 项乘上 (sum_{i=1}^{n}a_i^{k}),而 (sum_{i=1}^{n}a_i^{k}) 是可以快速求出的(在补充部分介绍)。
尝试把 (prod_{i=1}^{n}Q(a_i x)) 也化成加法形式:利用对数,可以得到 (prod_{i=1}^{n}Q(a_i x) = exp(sum_{i=1}^{n}ln(Q(a_i x))))
之后就没有了。只要求出了 (sum_{i=1}^{n}a_i^{k}),剩下的都是模板。
@补充@
关于如何求 (sum_{i=1}^{n}a_i^{k}),其实方法比较多,这里介绍一种:
注意到 (ln(1 - x) = -sum_{i=1}frac{x^i}{i}),那么只要求出 (sum_{i=1}^{n}ln(1 - a_ix)),也就求出了 (sum_{i=1}^{n}a_i^{k})。
利用对数的性质,有 (sum_{i=1}^{n}ln(1 - a_ix) = ln(prod_{i=1}^{n}(1 - a_ix)))。
然后里面那个式子分治 fft 可以 O(nlog^2n) 搞定,这样一来总时间复杂度其实就是 O(nlog^2n)。
@accepted code@
#include <cstdio>
#include <algorithm>
using namespace std;
const int MAXN = 4*30000;
const int MOD = 998244353;
struct mint{
int x;
mint(int _x=0) : x(_x) {}
friend mint operator + (mint a, mint b) {
return a.x + b.x >= MOD ? a.x + b.x - MOD : a.x + b.x;
}
friend mint operator - (mint a, mint b) {
return a.x - b.x < 0 ? a.x - b.x + MOD : a.x - b.x;
}
friend mint operator * (mint a, mint b) {
return (int)(1LL * a.x * b.x % MOD);
}
friend mint pow_mod(mint b, int p) {
mint ret = 1;
while( p ) {
if( p & 1 ) ret *= b;
b *= b;
p >>= 1;
}
return ret;
}
friend mint operator / (mint a, mint b) {
return a * pow_mod(b, MOD - 2);
}
friend void operator += (mint &a, mint b) {a = a + b;}
friend void operator -= (mint &a, mint b) {a = a - b;}
friend void operator *= (mint &a, mint b) {a = a * b;}
friend void operator /= (mint &a, mint b) {a = a / b;}
};
namespace poly{
const mint G = 3;
mint w[20], iw[20], inv[MAXN + 5];
void init() {
for(int i=0;i<20;i++) {
w[i] = pow_mod(G, (MOD-1)/(1<<i));
iw[i] = pow_mod(w[i], MOD-2);
}
inv[1] = 1;
for(int i=2;i<=MAXN;i++)
inv[i] = 0 - (MOD/i)*inv[MOD%i];
}
void debug(mint *A, int n) {
for(int i=0;i<n;i++)
printf("%d ", A[i].x);
puts("");
}
void ntt(mint *A, int n, int type) {
for(int i=0,j=0;i<n;i++) {
if( i < j ) swap(A[i], A[j]);
for(int k=(n>>1);(j^=k)<k;k>>=1);
}
for(int i=1;(1<<i)<=n;i++) {
int s = (1 << i), t = (s >> 1);
mint u = (type == 1 ? w[i] : iw[i]);
for(int j=0;j<n;j+=s) {
mint p = 1;
for(int k=0;k<t;k++,p*=u) {
mint x = A[j+k], y = A[j+k+t];
A[j+k] = x + y*p, A[j+k+t] = x - y*p;
}
}
}
if( type == -1 ) {
mint iv = inv[n];
for(int i=0;i<n;i++)
A[i] *= iv;
}
}
int length(int n) {
int len; for(len = 1; len < n; len <<= 1);
return len;
}
void pcopy(mint *A, mint *B, int n, int l) {
for(int i=0;i<n;i++) A[i] = B[i];
for(int i=n;i<l;i++) A[i] = 0;
}
mint t1[MAXN + 5], t2[MAXN + 5];
void pmul(mint *A, int nA, mint *B, int nB, mint *C) {
int len = length(nA + nB - 1);
pcopy(t1, A, nA, len), ntt(t1, len, 1);
pcopy(t2, B, nB, len), ntt(t2, len, 1);
for(int i=0;i<len;i++) C[i] = t1[i] * t2[i];
ntt(C, len, -1);
}
mint t3[MAXN + 5], t4[MAXN + 5];
void pinv(mint *A, mint *B, int n) {
if( n == 1 ) {
B[0] = 1 / A[0];
return ;
}
int m = (n + 1) >> 1; pinv(A, B, m);
int len = length(n << 1);
pcopy(t3, A, n, len), ntt(t3, len, 1);
pcopy(t4, B, m, len), ntt(t4, len, 1);
for(int i=0;i<len;i++) B[i] = t4[i]*(2 - t3[i]*t4[i]);
ntt(B, len, -1);
}
void pdif(mint *A, mint *B, int n) {
for(int i=1;i<n;i++)
B[i-1] = A[i] * i;
}
void pint(mint *A, mint *B, int n) {
for(int i=n-1;i>=0;i--)
B[i+1] = A[i] / (i + 1);
B[0] = 0;
}
mint t5[MAXN + 5], t6[MAXN + 5];
void pln(mint *A, mint *B, int n) {
pinv(A, t5, n), pdif(A, t6, n);
pmul(t5, n, t6, n, B);
pint(B, B, n);
}
mint t7[MAXN + 5], t8[MAXN + 5];
void pexp(mint *A, mint *B, int n) {
if( n == 1 ) {
B[0] = 1;
return ;
}
int m = (n + 1) >> 1; pexp(A, B, m);
int len = length(n << 1);
pcopy(t7, B, m, len), pln(t7, t8, n), pcopy(t7, t8, n, len);
pcopy(t8, B, m, len);
for(int i=0;i<n;i++) t7[i] = A[i] - t7[i];
t7[0] = t7[0] + 1;
ntt(t7, len, 1), ntt(t8, len, 1);
for(int i=0;i<len;i++) B[i] = t7[i] * t8[i];
ntt(B, len, -1);
}
}
int n, m, k;
mint A[MAXN + 5], B[MAXN + 5];
void init() {
mint t = 1;
for(int i=0;i<n;i++,t*=i) {
mint a = 1 / t, b = pow_mod(mint(i + 1), m);
A[i] = a * b * b, B[i] = a * b;
}
poly::init();
}
mint a[MAXN + 5], f[MAXN + 5], s[MAXN + 5];
int solve(int l, int r) {
if( l == r ) {
f[l<<1] = 1, f[l<<1|1] = 0 - a[l];
return 2;
}
int mid = (l + r) >> 1;
int ls = solve(l, mid), rs = solve(mid + 1, r);
poly::pmul(f + (l<<1), ls, f + ((mid + 1) << 1), rs, f + (l << 1));
return ls + rs - 1;
}
void get_pow_sum() {
solve(0, n - 1), poly::pln(f, s, n + 1);
s[0] = n;
for(int i=1;i<=n;i++)
s[i] = 0 - s[i]*i;
}
mint t1[MAXN + 5], t2[MAXN + 5];
int main() {
scanf("%d%d", &n, &m), k = n - 2, init();
for(int i=0;i<n;i++) scanf("%d", &a[i].x);
get_pow_sum();
poly::pln(B, t1, n);
for(int i=0;i<n;i++)
t1[i] *= s[i];
poly::pexp(t1, t2, n);
poly::pinv(B, t1, n);
poly::pmul(A, n, t1, n, t1);
for(int i=0;i<n;i++)
t1[i] *= s[i];
poly::pmul(t1, n, t2, n, t1);
mint ans = t1[n - 2];
for(int i=0;i<n;i++) ans *= a[i];
for(int i=1;i<=n-2;i++) ans *= i;
printf("%d
", ans.x);
}
@details@
顺带一提,这道题还有依赖于斯特林数的 O(nmlogn) 的做法(但是我看不懂 QaQ)。