其实有原题,生成树计数
然鹅这题里面是两道题, 50pts 可以用上面那题的做法直接过掉,另外 50pts 要推推式子,搞出 O n 的做法才行(毕竟多项式常数之大您是知道的)
虽说这道题里面是没有 a_i 的,也不用分治合并多项式的就是了,所以大致思路看我另一题的题解就好了,这里对于前 50pts 的做法只给出式子:
[ANS_n= {(n-2)! Big( [x^{n-2}] ig(sum_{i=0}^infty (i+1) ^m {x^i over i! } ig)^n Big)over n^{n-2}}
]
我们先康康我们原本要求的多项式变成了什么:
[[x^{n-2}] ig(sum_{i=0}^infty (i+1) {x^iover i!} ig)^n
]
然后我们就考虑转成 EXP 咯
[egin{aligned} &[x^{n-2}]Big(sum_{i=0}^infty (i+1) {x^iover i!} Big)^n\=& [x^{n-2}]Big(e^x(x+1)Big)^n \=&[x^{n-2}] e^{nx}·(x+1)^n \=& sum_{i=2}^{n} {n^{i-2}over (i-2)!} ·{n!over (n-i)!· i!} end{aligned}
]
注意,这里乱转 EXP 的时候千万要记得运算,不然就像我一样多加了一个 -x 然后死都化不出来了
然后咱预处理完 阶乘 及其 逆元 就可以 O n 出解了
//by Judge
#pragma GCC optimize("Ofast")
#include<bits/stdc++.h>
#define Rg register
#define fp(i,a,b) for(Rg int i=(a),I=(b)+1;i<I;++i)
#define fd(i,a,b) for(Rg int i=(a),I=(b)-1;i>I;--i)
#define ll long long
using namespace std;
const int mod=998244353;
const int iG=332748118;
const int M=5e6+3;
typedef int arr[M];
#ifndef Judge
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
#endif
char buf[1<<21],*p1=buf,*p2=buf;
inline int inc(int x,int y){return (x+=y)>=mod?x-mod:x;}
inline int dec(int x,int y){return (x-=y)<0?x+mod:x;}
inline int mul(int x,int y){return 1ll*x*y%mod;}
inline int read(){ int x=0,f=1; char c=getchar();
for(;!isdigit(c);c=getchar()) if(c=='-') f=-1;
for(;isdigit(c);c=getchar()) x=x*10+c-'0'; return x*f;
} int n,m,res,limit; arr fac,finv,A,B,C,r;
inline int qpow(Rg int x,Rg int p=mod-2,int s=1){
for(;p;p>>=1,x=mul(x,x)) if(p&1) s=mul(s,x); return s;
}
inline void init(int n){ int l=-1;
for(limit=1;limit<n;limit<<=1)++l;
fp(i,0,limit-1) r[i]=(r[i>>1]>>1)|((i&1)<<l);
}
inline void NTT(int* a,int tp){
fp(i,0,limit-1) if(i<r[i]) swap(a[i],a[r[i]]);
for(Rg int mid=1;mid<limit;mid<<=1){
int Gn=qpow(tp?3:iG,(mod-1)/(mid<<1));
for(Rg int j=0,I=mid<<1,x,y;j<limit;j+=I)
for(Rg int k=0,g=1;k<mid;++k,g=mul(g,Gn))
x=a[j+k],y=mul(a[j+k+mid],g),
a[j+k]=(x+y)%mod,a[j+k+mid]=(x-y+mod)%mod;
} if(tp) return; int inv=qpow(limit);
fp(i,0,limit-1) a[i]=mul(a[i],inv);
}
void Inv(int* a,int* b,int n){ static arr C,D;
if(n==1) return b[0]=qpow(a[0]),void();
Inv(a,b,n>>1),init(n<<1);
fp(i,0,n-1) C[i]=a[i],D[i]=b[i];
fp(i,n,limit-1) C[i]=D[i]=0; NTT(C,1),NTT(D,1);
fp(i,0,limit-1) C[i]=mul(C[i],mul(D[i],D[i]));
NTT(C,0); fp(i,n,limit-1) b[i]=0;
fp(i,0,n-1) b[i]=dec(inc(b[i],b[i]),C[i]);
}
inline void Direv(int* a,int* b,int n){
fp(i,1,n-1) b[i-1]=mul(a[i],i); b[n-1]=0;
}
inline void Inter(int* a,int* b,int n){
fp(i,1,n-1) b[i]=mul(a[i-1],qpow(i)); b[0]=0;
}
void Ln(int* a,int* b,int n){ static arr C,D;
Inv(a,C,n),Direv(a,D,n),init(n<<1);
fp(i,n,limit-1) C[i]=D[i]=0; NTT(C,1),NTT(D,1);
fp(i,0,limit-1) C[i]=mul(C[i],D[i]); NTT(C,0),Inter(C,b,n);
}
void Exp(int* a,int* b,int n){
if(n==1) return b[0]=1,void(); static arr B;
Exp(a,b,n>>1),Ln(b,B,n),B[0]=dec(a[0]+1,B[0]); init(n<<1);
fp(i,1,n-1) B[i]=dec(a[i],B[i]); fp(i,n,limit-1) B[i]=0;
NTT(B,1),NTT(b,1); fp(i,0,limit-1) b[i]=mul(b[i],B[i]);
NTT(b,0); fp(i,n,limit-1) b[i]=B[i]=0;
}
int main(){
/// pre calc
n=2e6,fac[0]=finv[0]=finv[1]=1;
fp(i,1,n) fac[i]=mul(fac[i-1],i);
fp(i,2,n) finv[i]=mul(mod-mod/i,finv[mod%i]);
fp(i,2,n) finv[i]=mul(finv[i-1],finv[i]);
fp(Stp,1,read()){ n=read(),m=read();
Rg int len=1; while(len<=n) len<<=1;
if(m==1){
Rg int x=1,ans=0;
fp(i,2,n) ans=inc(ans,mul(x,mul(fac[n],mul(finv[i-2],mul(finv[n-i],finv[i]))))),x=mul(x,n);
printf("%d
",mul(mul(fac[n-2],ans),qpow(qpow(n,n-2))));
} else{
fp(i,0,n) A[i]=mul(qpow(i+1,m),finv[i]); Ln(A,B,len);
fp(i,0,n) B[i]=mul(B[i],n),A[i]=0; Exp(B,A,len);
printf("%d
",mul(mul(fac[n-2],A[n-2]),qpow(qpow(n,n-2))));
memset(A,0,(len+2)<<3);
}
} return 0;
}