比较神仙的推导.
求 $sum_{n=0}^{ infty }s(n)r^n$,其中 $s(x)$ 是一个 $m$ 次多项式,$0leqslant r leqslant 1$
显然可以 $s(x)$ 每一个系数的贡献,那么就转化为:
$sum_{j=0}^{m} a_{j} sum_{n=0}^{infty} n^jr^n.$
令 $f_{j}=sum_{n=0}^{infty} n^jr^n$.
$(1-r)f_{j}=sum_{n=0}^{infty} n^jr^n-n^jr^{n+1}$
$Rightarrow sum_{n=1}^{infty} [n^j-(n-1)^j] r^n$
$Rightarrow rsum_{n=0}^{infty} [(n+1)^j-n^j]r^n$
二项式展开,得 $rsum_{i=0}^{j-1}inom{j}{i}sum_{n=0}^{infty}n^ir^n$
然后就可以写成分治 NTT 的形式了:
$frac{f_{j}}{j!}=sum_{i=0}^{j-1} frac{f_{i}}{i!} frac{r}{(j-i)!(1-r)}$
其中 $f_{0}=frac{1}{1-r}$
code:
#include <cstdio> #include <vector> #include <cstring> #include <algorithm> #define N 100009 #define ll long long #define mod 998244353 #define setIO(s) freopen(s".in","r",stdin) using namespace std; int m,V; int inv[N],fac[N]; int A[N<<2],B[N<<2],f[N],g[N],seq[N]; int qpow(int x,int y) { int tmp=1; for(;y;y>>=1,x=(ll)x*x%mod) { if(y&1) tmp=(ll)tmp*x%mod; } return tmp; } int get_inv(int x) { return qpow(x,mod-2); } void NTT(int *a,int len,int op) { for(int i=0,k=0;i<len;++i) { if(i>k) swap(a[i],a[k]); for(int j=len>>1;(k^=j)<j;j>>=1); } for(int l=1;l<len;l<<=1) { int wn=qpow(3,(mod-1)/(l<<1)); if(op==-1) { wn=get_inv(wn); } for(int i=0;i<len;i+=l<<1) { int w=1,x,y; for(int j=0;j<l;++j) { x=a[i+j],y=(ll)a[i+j+l]*w%mod; a[i+j]=(ll)(x+y)%mod; a[i+j+l]=(ll)(x-y+mod)%mod; w=(ll)w*wn%mod; } } } if(op==-1) { int in=get_inv(len); for(int i=0;i<len;++i) { a[i]=(ll)a[i]*in%mod; } } } void init() { fac[0]=1,inv[1]=1; for(int i=1;i<N;++i) fac[i]=(ll)fac[i-1]*i%mod; for(int i=2;i<N;++i) { inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod; } inv[0]=1; for(int i=1;i<N;++i) { inv[i]=(ll)inv[i-1]*inv[i]%mod; } } void solve(int l,int r) { if(l==r) { return; } int mid=(l+r)>>1,lim,s1=0,s2=0; solve(l,mid); for(int i=l;i<=mid;++i) A[s1++]=f[i]; for(int i=0;i<=r-l;++i) B[s2++]=g[i]; for(lim=1;lim<(s1+s1);lim<<=1); for(int i=s1;i<lim;++i) A[i]=0; for(int i=s2;i<lim;++i) B[i]=0; NTT(A,lim,1),NTT(B,lim,1); for(int i=0;i<lim;++i) A[i]=(ll)A[i]*B[i]%mod; NTT(A,lim,-1); for(int i=mid+1;i<=r;++i) { (f[i]+=A[i-l])%=mod; } for(int i=0;i<lim;++i) A[i]=B[i]=0; solve(mid+1,r); } int main() { // setIO("input"); init(); scanf("%d%d",&m,&V); for(int i=0;i<=m;++i) { scanf("%d",&seq[i]); } f[0]=get_inv((ll)(1-V+mod)%mod); int in=(ll)V*f[0]%mod; for(int i=1;i<=m;++i) { g[i]=(ll)inv[i]*in%mod; } solve(0,m); for(int i=1;i<=m;++i) { f[i]=(ll)f[i]*fac[i]%mod; } int ans=0; for(int i=0;i<=m;++i) { (ans+=(ll)seq[i]*f[i]%mod)%=mod; } printf("%d ",ans); return 0; }