题面
解析
考虑任一长度为$n-2$、每个元素取值于$[1,n]$的序列(即Prüfer序列):它唯一对应一棵$n$个节点的有标号无根树,反之亦然,即树和Prüfer序列之间是双射关系。并且,节点$i$在树中的度数等于它在Prüfer序列中的出现次数加一。
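那么可以将问题转化为枚举Prüfer序列。设$d_i$为节点$i$在序列中的出现次数(于是节点$i$的度数为$d_i+1$),出现次数恰为$\{d_i\}$的序列共有$\frac{(n-2)!}{\prod_id_i!}$个,故:$$\begin{align*}Ans&=\sum_{\sum_{i}d_i=n-2}\frac{(n-2)!}{\prod_id_i!}\left(\prod_ia_i^{d_i+1}(d_i+1)^m\right)\left(\sum_i(d_i+1)^m\right)\\&=(n-2)!\cdot\left(\prod_ia_i\right)\cdot\sum_{\sum_{i}d_i=n-2}\left(\prod_i\frac{a_i^{d_i}(d_i+1)^m}{d_i!}\right)\left(\sum_i(d_i+1)^m\right)\\&=(n-2)!\cdot\left(\prod_ia_i\right)\cdot\sum_{\sum_{i}d_i=n-2}\sum_i\left((d_i+1)^m\prod_j\frac{a_j^{d_j}(d_j+1)^m}{d_j!}\right)\\&=(n-2)!\cdot\left(\prod_ia_i\right)\cdot\sum_{\sum_{i}d_i=n-2}\sum_i\left(\frac{a_i^{d_i}(d_i+1)^{2m}}{d_i!}\prod_{j\neq i}\frac{a_j^{d_j}(d_j+1)^m}{d_j!}\right)\end{align*}$$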
设$$\begin{gather*}A(x)=\sum_{i=0}^{\infty}\frac{(i+1)^{2m}}{i!}x^i\\ B(x)=\sum_{i=0}^{\infty}\frac{(i+1)^{m}}{i!}x^i\\ F(x)=\sum_iA(a_ix)\prod_{j\neq i}B(a_jx)\end{gather*}$$则上式最后一行中对$\{d_i\}$的求和恰为$[x^{n-2}]F(x)$。
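对$F(x)$化简(由于$B(0)=1$,$B(a_ix)$可逆且$\ln B(a_ix)$有定义):$$\begin{align*}F(x)&=\sum_iA(a_ix)\prod_{j\neq i}B(a_jx)\\&=\left(\sum_i\frac{A(a_ix)}{B(a_ix)}\right)\cdot\prod_iB(a_ix)\\&=\left(\sum_i\frac{A(a_ix)}{B(a_ix)}\right)\cdot\exp\left(\ln\prod_iB(a_ix)\right)\\&=\left(\sum_i\frac{A(a_ix)}{B(a_ix)}\right)\cdot\exp\left(\sum_i\ln B(a_ix)\right)\end{align*}$$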
再设$$C(x)=\frac{A(x)}{B(x)},\qquad D(x)=\ln B(x)$$
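由于对任意幂级数$G$都有$[x^j]G(a_ix)=a_i^j\,[x^j]G(x)$,有:$$\begin{gather*}[x^j]\sum_i\ln B(a_ix)=\left([x^j]D(x)\right)\cdot\sum_ia_i^j\\ [x^j]\sum_i\frac{A(a_ix)}{B(a_ix)}=\left([x^j]C(x)\right)\cdot\sum_ia_i^j\end{gather*}$$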
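求出$C(x)$与$D(x)$后,将它们的第$i$项乘以$\sum_ja_j^i$,也就是需要求数列的$i$次方和。我在生成函数小结里有写,这里只简述做法:用分治NTT求出$\prod_j(1-a_jx)$,由$-x\frac{\mathrm{d}}{\mathrm{d}x}\ln\prod_j(1-a_jx)=\sum_{k\ge1}\left(\sum_ja_j^k\right)x^k$即可得到所有$k\ge1$的次方和(另有$\sum_ja_j^0=n$)。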
最终答案:$$Ans=(n-2)!\cdot\left(\prod_ia_i\right)\cdot[x^{n-2}]F(x)$$
时间复杂度$O(n\log^2 n)$,瓶颈在分治NTT求$\prod_i(1-a_ix)$;其余的多项式求逆、$\ln$、$\exp$均为$O(n\log n)$。
代码:
// Sums, over every labelled tree on n vertices, the weight
//   (prod_i a_i^{deg_i} * deg_i^m) * (sum_i deg_i^m)   (mod 998244353),
// where deg_i = d_i + 1 and d_i is the number of occurrences of vertex i in
// the tree's Prufer sequence.  Following the derivation above:
//   answer = (n-2)! * prod_i a_i * [x^{n-2}] F(x),
//   F(x)   = (sum_i C(a_i x)) * exp(sum_i D(a_i x)),  C = A / B,  D = ln B.
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>
using namespace std;
typedef long long ll;

// maxn bounds n: fac / fnv / inv / a are indexed 0..n.
const int maxn = 60005, mod = 998244353, g = 3;
// Length of every NTT buffer.  The longest transform is requested by
// NTT_init(2 * (n + 1)); its length (the smallest power of two greater than
// 2 * (n + 1)) is at most maxp because 2 * (n + 1) <= 2 * maxn < maxp.
// The former size, maxn << 1 = 120010, was overrun for n >= 32767.
const int maxp = 1 << 17;
static_assert(maxp > 2 * maxn, "NTT buffers too small for n < maxn");

// Modular addition / subtraction of two residues in [0, mod).
int add(int x, int y) { return x + y < mod ? x + y : x + y - mod; }
int rdc(int x, int y) { return x - y < 0 ? x - y + mod : x - y; }

// x^y mod p by binary exponentiation; qpow(x, 0) == 1 (including x == 0).
ll qpow(ll x, int y) {
    ll ret = 1;
    while (y) {
        if (y & 1) ret = ret * x % mod;
        x = x * x % mod;
        y >>= 1;
    }
    return ret;
}

int n, m, lim, bit, rev[maxp], a[maxn];
ll ginv, fac[maxn], fnv[maxn], inv[maxn];
// Polynomial buffers.  c, d, ln and iv are scratch space shared by the
// routines below; each routine leaves the scratch it used zeroed.
ll A[maxp], B[maxp], c[maxp], d[maxp], ln[maxp], iv[maxp], f[maxp], h[maxp];
// G[x] = polynomial of segment-tree node x in solve().  Node ids stay below
// 4 * n; the former size maxn << 1 could be exceeded for large n.
vector<int> G[maxn << 2];

// Precomputes g^{-1}, fac[i] = i!, inv[i] = i^{-1} and fnv[i] = (i!)^{-1}
// for 0 <= i <= n.
void init() {
    ginv = qpow(g, mod - 2);
    fac[0] = 1;
    for (int i = 1; i <= n; ++i) fac[i] = fac[i - 1] * i % mod;
    inv[0] = inv[1] = fnv[0] = fnv[1] = 1;
    for (int i = 2; i <= n; ++i) {
        inv[i] = (mod - mod / i) * inv[mod % i] % mod;
        fnv[i] = fnv[i - 1] * inv[i] % mod;
    }
}

// Sets lim to the smallest power of two strictly greater than x (so a product
// of degree <= x does not wrap around) and fills the bit-reversal table rev.
void NTT_init(int x) {
    lim = 1;
    bit = 0;
    while (lim <= x) {
        lim <<= 1;
        ++bit;
    }
    for (int i = 1; i < lim; ++i)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}

// In-place iterative NTT of x[0..lim): y == 1 forward, y == -1 inverse
// (the inverse also divides by lim).
void NTT(ll *x, int y) {
    for (int i = 1; i < lim; ++i)
        if (i < rev[i]) swap(x[i], x[rev[i]]);
    ll wn, w, u, v;
    for (int i = 1; i < lim; i <<= 1) {
        wn = qpow((y == 1) ? g : ginv, (mod - 1) / (i << 1));
        for (int j = 0; j < lim; j += (i << 1)) {
            w = 1;
            for (int k = 0; k < i; ++k) {
                u = x[j + k];
                v = x[j + k + i] * w % mod;
                x[j + k] = add(u, v);
                x[j + k + i] = rdc(u, v);
                w = w * wn % mod;
            }
        }
    }
    if (y == -1) {
        ll linv = qpow(lim, mod - 2);
        for (int i = 0; i < lim; ++i) x[i] = x[i] * linv % mod;
    }
}

// x = y^{-1} mod t^len by Newton iteration x <- x * (2 - y * x).
// Requires y[0] != 0 and x zero-initialised; uses c as scratch.
void get_inv(ll *x, ll *y, int len) {
    if (len == 1) {
        x[0] = qpow(y[0], mod - 2);
        return;
    }
    get_inv(x, y, (len + 1) >> 1);
    for (int i = 0; i < len; ++i) c[i] = y[i];
    NTT_init(len << 1);
    NTT(x, 1);
    NTT(c, 1);
    for (int i = 0; i < lim; ++i) {
        x[i] = x[i] * rdc(2, c[i] * x[i] % mod) % mod;
        c[i] = 0;
    }
    NTT(x, -1);
    for (int i = len; i < lim; ++i) x[i] = 0;
}

// x = ln y mod t^len, computed as the integral of y' / y.
// Assumes y[0] == 1 (x[0] is set to 0); reads y[0..len] and requires x to be
// zero from index len onwards.  Uses iv (and c, via get_inv) as scratch.
void get_ln(ll *x, ll *y, int len) {
    for (int i = 0; i < len; ++i) x[i] = y[i + 1] * (i + 1) % mod;
    get_inv(iv, y, len);
    NTT_init(len << 1);
    NTT(x, 1);
    NTT(iv, 1);
    for (int i = 0; i < lim; ++i) {
        x[i] = x[i] * iv[i] % mod;
        iv[i] = 0;
    }
    NTT(x, -1);
    for (int i = len - 1; i >= 1; --i) x[i] = x[i - 1] * inv[i] % mod;
    x[0] = 0;
    for (int i = len; i < lim; ++i) x[i] = 0;
}

// x = exp(y) mod t^len by Newton iteration x <- x * (1 - ln x + y).
// Assumes y[0] == 0 and x zero-initialised; uses ln and c as scratch.
void get_exp(ll *x, ll *y, int len) {
    if (len == 1) {
        x[0] = 1;
        return;
    }
    get_exp(x, y, (len + 1) >> 1);
    get_ln(ln, x, len);
    for (int i = 0; i < len; ++i) {
        c[i] = add(i == 0, rdc(y[i], ln[i]));
        ln[i] = 0;
    }
    NTT_init(len << 1);
    NTT(x, 1);
    NTT(c, 1);
    for (int i = 0; i < lim; ++i) {
        x[i] = x[i] * c[i] % mod;
        c[i] = 0;
    }
    NTT(x, -1);
    for (int i = len; i < lim; ++i) x[i] = 0;
}

// Divide and conquer: G[x] = prod_{i=l..r} (1 - y[i] t), stored as r - l + 2
// coefficients.  Uses c and d as scratch and leaves them zeroed.
void solve(int x, int l, int r, int *y) {
    if (l == r) {
        G[x].push_back(1);
        G[x].push_back(rdc(0, y[l]));
        return;
    }
    const int lc = x << 1, rc = lc | 1;
    const int mid = (l + r) >> 1;
    solve(lc, l, mid, y);
    solve(rc, mid + 1, r, y);
    for (int i = 0; i <= mid - l + 1; ++i) c[i] = G[lc][i];
    for (int i = 0; i <= r - mid; ++i) d[i] = G[rc][i];
    NTT_init(r - l + 1);
    NTT(c, 1);
    NTT(d, 1);
    for (int i = 0; i < lim; ++i) {
        c[i] = c[i] * d[i] % mod;
        d[i] = 0;
    }
    NTT(c, -1);
    for (int i = 0; i <= r - l + 1; ++i) {
        G[x].push_back(c[i]);
        c[i] = 0;
    }
    for (int i = r - l + 2; i < lim; ++i) c[i] = 0;
}

int main() {
    scanf("%d%d", &n, &m);
    if (n == 1) {
        // A single vertex has degree 0, so by the weight formula its tree is
        // worth a^0 * 0^m * 0^m (qpow(0, 0) == 1).  The general path below
        // would read fac[-1] and A[-1] here.
        printf("%lld", qpow(0, m) * qpow(0, m) % mod);
        return 0;
    }
    init();
    ll ans = fac[n - 2];
    for (int i = 1; i <= n; ++i) {
        scanf("%d", &a[i]);
        ans = ans * a[i] % mod;
    }

    // Power sums f[k] = sum_i a_i^k for k <= n, using
    // -k * [t^k] ln prod_i (1 - a_i t) = sum_i a_i^k  and  f[0] = n.
    solve(1, 1, n, a);
    for (int i = 0; i <= n; ++i) d[i] = G[1][i];
    get_ln(f, d, n + 1);
    for (int i = n; i >= 1; --i) {
        f[i] = rdc(0, f[i] * i % mod);
        d[i] = 0;
    }
    f[0] = n;

    // B[i] = (i+1)^m / i!,  A[i] = (i+1)^{2m} / i!.
    for (int i = 0; i <= n; ++i) {
        const ll pw = qpow(i + 1, m);
        B[i] = pw * fnv[i] % mod;
        A[i] = B[i] * pw % mod;
    }
    get_ln(d, B, n + 1);   // d = D = ln B
    get_inv(h, B, n + 1);  // h = 1 / B
    NTT_init(n << 1);
    NTT(A, 1);
    NTT(h, 1);
    for (int i = 0; i < lim; ++i) A[i] = A[i] * h[i] % mod;
    NTT(A, -1);            // A = C = A / B (only coefficients 0..n are kept)

    // Substitute x -> a_i x and sum over i: scale coefficient k by f[k].
    for (int i = 0; i <= n; ++i) {
        A[i] = A[i] * f[i] % mod;
        d[i] = d[i] * f[i] % mod;
    }
    for (int i = n + 1; i < lim; ++i) A[i] = 0;

    memset(B, 0, sizeof(B));
    get_exp(B, d, n + 1);  // B = prod_i B(a_i x)
    NTT_init(n << 1);
    NTT(A, 1);
    NTT(B, 1);
    for (int i = 0; i < lim; ++i) A[i] = A[i] * B[i] % mod;
    NTT(A, -1);            // A = F
    printf("%lld", ans * A[n - 2] % mod);
    return 0;
}