在一个(s)个点的图中,存在(s-n)条边,使图中形成了(n)个连通块,第(i)个连通块中有(a_i)个点。
现在我们需要再连接(n-1)条边,使该图变成一棵树。对一种连边方案,设原图中第(i)个连通块连出了(d_i)条边,那么这棵树(T)的价值为:
你的任务是求出所有可能的生成树的价值之和,对(998244353)取模。
(可能只有我没读出来题目说连通块内的连边方式不计。)
树和每个点的度数可以联想到(prufer)序列。那么设(c_i)为第(i)个点在(prufer)序列中出现的次数,则(c_i=d_i-1)。考虑对于一个确定的序列(c_i),它对答案的贡献就是
第一项是这个序列对应的有标号无根树个树;第二个是因为每个连通块中的点可以任意分配这个连通块的出边;后面两个是题面定义的价值。
那么有一个暴力思路就是递推,在那之前我们把((n-2)!prod_{i=1}^na_i)看作常数项,只考虑剩余的式子。设
即递推前(n)个点的总(c_i)为(m)的所有情况之和。转移就有:
答案就是(g_{n,n-2} imes frac{(n-2)!}{prod_{i=1}^na_i})。
现在就可以有(20)分的好成绩了。如果用(NTT)实现上面的转移就可以有(40)分的好成绩。
然后发现我这个式子并不好优化(懒得优化两个式子)。瞟一眼题解之后发现开始那个式子可以化得好看些:
我们还是不管前面的常数项。可以发现每个(c_i)对式子的贡献就是(frac{a_i^{c_i}(c_i+1)^m}{c_i!})或者一个序列中仅有一个(c_i)贡献为(frac{a_i^{c_i}(c_i+1)^{2m}}{c_i!}),那么这个式子就可以由若干个次数表示(c)的(EGF)乘起来(实际上如果尝试用(EGF)推一下那个(n^3)递推可以更容易发现这种性质)。乍一看一共有(n)个系数不同的(n)次多项式,似乎不可做,但是第(i)个多项式每一项都有(a_i)的若干次方,且与(x)次数相同,所以这些多项式都可以写成(F(a_ix))的形式。因此设
答案就是(sum_{i=1}^nfrac{A(a_ix)}{B(a_ix)}prod_{j=1}^nB(a_jx))。这样有什么好处呢?这里补充一下这种trick。
如果式子可以写成(sum_{i=1}^nF(a_ix))的形式,并且对任意(m)都求出了(sum_{i=1}^na_i^m),那么只要求出(F(x)),式子就可以变成
因此我们求出(frac{A(x)}{B(x)})和(prod_{i=1}^nB(a_ix))即可。但后面这个是(prod),和前面的(sum)不同,这里就要取个(ln):
那么我们算一下(ln{B(x)})就可以像上面那样算了。
现在唯一的问题就是怎么对每个(m)求出(sum_{i=1}^na_i^m)。类似于自然数幂和的推导,我们写出这个东西的(OGF)就有
这就有点像P4705玩游戏这题的技巧,因为
那设(H(x)=sum_{i=1}^n(ln(1-a_ix))'),那(G(x)=-xH(x)+n)。求(H)就:
分治(NTT)即可。至此这题就解决了,复杂度瓶颈为最后分治(NTT)的(mathcal{O}(nlog^2n))。
#include<bits/stdc++.h>
#define rg register
#define il inline
#define cn const
#define gc getchar()
#define fp(i, a, b) for(int i = (a), ed = (b); i <= ed; ++i)
#define fb(i, a, b) for(int i = (a), ed = (b); i >= ed; --i)
#define go(u) for(int i = head[u]; ~i; i = e[i].nxt)
using namespace std;
typedef cn int cint;
typedef long long LL;
il void rd(int &x){
x = 0;
rg int f(1); rg char c(gc);
while(c < '0' || '9' < c){if(c == '-')f = -1; c = gc;}
while('0' <= c && c <= '9')x = (x<<1)+(x<<3)+(c^48), c = gc;
x *= f;
}
cint maxn = 30010, mod = 998244353, G = 3, invG = (mod+1)/3;
int n, m, a[maxn], fac[maxn], ifac[maxn], inv[maxn], mul = 1;
int lim, hst, rev[maxn<<2], A[maxn<<2], B[maxn<<2], ln[maxn<<2], invB[maxn<<2];
int H[maxn<<2], E[maxn<<2];
il int fpow(int a, int b, int ans = 1){
for(; b; b >>= 1, a = 1ll*a*a%mod)if(b&1)ans = 1ll*ans*a%mod;
return ans;
}
il void ntt(int *a, cint &typ){
fp(i, 0, lim-1)if(i > rev[i])swap(a[i], a[rev[i]]);
for(rg int md = 1; md < lim; md <<= 1){
rg int len = md<<1, Gn = fpow(typ ? invG : G, (mod-1)/len);
for(rg int l = 0; l < lim; l += len){
for(rg int nw = 0, Pow = 1; nw < md; ++nw, Pow = 1ll*Pow*Gn%mod){
rg int x = a[l+nw], y = 1ll*a[l+nw+md]*Pow%mod;
a[l+nw] = (x+y)%mod, a[l+nw+md] = (x-y+mod)%mod;
}
}
}
if(typ){
rg int inv = fpow(lim, mod-2);
fp(i, 0, lim-1)a[i] = 1ll*a[i]*inv%mod;
}
}
il void init(int n){
lim = 1, hst = 0;
while(lim < n)lim <<= 1, ++hst;
fp(i, 0, lim-1)rev[i] = (rev[i>>1]>>1)|((i&1)<<hst-1);
}
int inv_ary[maxn<<2];
void get_inv(int *a, int *f, int n){
if(n == 1)return f[0] = fpow(a[0], mod-2), void();
get_inv(a, f, n+1>>1), init(2*n-1);
fp(i, 0, n-1)inv_ary[i] = a[i];
ntt(f, 0), ntt(inv_ary, 0);
fp(i, 0, lim-1)f[i] = 1ll*f[i]*(2-1ll*f[i]*inv_ary[i]%mod+mod)%mod;
ntt(f, 1);
fp(i, n, lim-1)f[i] = 0;
fp(i, 0, lim-1)inv_ary[i] = 0;
}
int ln_ary[maxn<<2];
il void get_ln(int *a, int *f, int n){
get_inv(a, ln_ary, n), init(2*n-2);
fp(i, 1, n-1)f[i-1] = 1ll*a[i]*i%mod;
ntt(f, 0), ntt(ln_ary, 0);
fp(i, 0, lim-1)f[i] = 1ll*f[i]*ln_ary[i]%mod;
ntt(f, 1);
fp(i, n-1, lim)f[i] = 0;
fp(i, 0, lim)ln_ary[i] = 0;
fb(i, n-1, 1)f[i] = 1ll*f[i-1]*inv[i]%mod;
f[0] = 0;
}
int exp_ary[maxn<<2];
void get_exp(int *a, int *f, int n){
if(n == 1)return f[0] = 1, void();
get_exp(a, f, n+1>>1), get_ln(f, exp_ary, n), init(2*n-1);
fp(i, 0, n-1)exp_ary[i] = (a[i]-exp_ary[i]+mod)%mod;
if((++exp_ary[0]) == mod)exp_ary[0] = 0;
ntt(f, 0), ntt(exp_ary, 0);
fp(i, 0, lim-1)f[i] = 1ll*f[i]*exp_ary[i]%mod;
ntt(f, 1);
fp(i, n, lim-1)f[i] = 0;
fp(i, 0, lim-1)exp_ary[i] = 0;
}
int div_ary[19][maxn<<2];
void divntt(int d, int l, int r){
if(l == r)return div_ary[d][0] = 1, div_ary[d][1] = mod-a[l], void();
int md = l+r>>1, len = r-l+2;
divntt(d, l, md), divntt(d+1, md+1, r), init(len), ntt(div_ary[d], 0), ntt(div_ary[d+1], 0);
fp(i, 0, lim-1)div_ary[d][i] = 1ll*div_ary[d][i]*div_ary[d+1][i]%mod;
ntt(div_ary[d], 1);
fp(i, len, lim-1)div_ary[d][i] = 0;
fp(i, 0, lim-1)div_ary[d+1][i] = 0;
}
int main(){
// freopen("in", "r", stdin);
rd(n), rd(m);
fp(i, 1, n)rd(a[i]), mul = 1ll*mul*a[i]%mod;
fac[0] = 1; fp(i, 1, n)fac[i] = 1ll*fac[i-1]*i%mod;
ifac[n] = fpow(fac[n], mod-2); fb(i, n, 1)ifac[i-1] = 1ll*ifac[i]*i%mod;
inv[1] = 1; fp(i, 2, n)inv[i] = mod-1ll*(mod/i)*inv[mod%i]%mod;
fp(i, 0, n)A[i] = 1ll*fpow(i+1, m<<1)*ifac[i]%mod;
fp(i, 0, n)B[i] = 1ll*fpow(i+1, m)*ifac[i]%mod;
get_ln(B, ln, n+1), get_inv(B, invB, n+1), init(2*n+1);
// fp(i, 0, n)printf("%d ", B[i]);puts("");
// fp(i, 0, n)printf("%d ", ln[i]);puts("");
// fp(i, 0, n)printf("%d ", invB[i]);puts("");
// fp(i, 0, n)printf("%d ", A[i]);puts("");
// fp(i, 0, n)printf("%d ", invB[i]);puts("");
ntt(invB, 0), ntt(A, 0);
fp(i, 0, lim-1)A[i] = 1ll*A[i]*invB[i]%mod;
ntt(A, 1);
fp(i, n+1, lim)A[i] = 0;
// fp(i, 0, n)printf("%d ", A[i]);puts("");
divntt(0, 1, n), get_ln(div_ary[0], H, n+1);
// fp(i, 0, n)printf("%d ", div_ary[0][i]);puts("");
fp(i, 1, n)H[i] = mod-1ll*H[i]*i%mod;
H[0] = n;
fp(i, 0, n)ln[i] = 1ll*ln[i]*H[i]%mod;
get_exp(ln, E, n+1);
fp(i, 0, n)A[i] = 1ll*A[i]*H[i]%mod;
init(2*n+1), ntt(E, 0), ntt(A, 0);
fp(i, 0, lim-1)A[i] = 1ll*A[i]*E[i]%mod;
ntt(A, 1), printf("%lld
", 1ll*fac[n-2]*mul%mod*A[n-2]%mod);
return 0;
}