原文链接https://www.cnblogs.com/zhouzhendong/p/UOJ450.html
题解
首先有一个东西叫做“单位根反演”,它在 FFT 的时候用到过:
$$frac 1 n sum_{i=0}^{n-1} omega_n ^{dcdot i} = [n|d]$$
其中 $omega_n$ 表示 $n$ 次单位根。
接下来我们回到本题。
我们来搞一个指数生成函数,第 i 项表示总共复读 i 次,使得一个复读机开心的方案。
$$f(x) = sum_{igeq 0} [d|i] frac{x^i} {i!}$$
那么我们要求的东西就是:
$$f^k(x)[n]$$
我们来给 $f(x)$ 推一推式子:
$$f(x) = sum_{igeq 0} [d|i] frac{x^i} {i!} \ = sum_{igeq 0} left( frac 1 d sum_{j=0}^{d-1} omega_d^{ij} ight) frac{x^i}{i!}\ = frac 1 d sum_{j=0}^{d-1} sum_{igeq 0} frac{x^i (omega _d ^j)^i}{i!}\ = frac 1 d sum_{i=0}^{d-1} e^{omega_d ^i x}$$
在 NTT 的时候,我们用原根的幂来代替单位根。
我们发现 19491001 是个质数,它的最小原根是 7 ,而且 19491000 = 2*2*2*3*5*5*5*73*89,含有因子 2 和 3,这说明能找到原根来分别代替 $omega _2$ 和 $omega_3$ 。
接下来我们分情况讨论:
d = 1 : ans = $k^n$ 。
d = 2 :
$$f(x) = frac 12 (e^x + e^{-x})$$
$$f^k(x)[n] = (frac 1 2 )^ksum_{i=0}^{k} inom k i c^n $$
由于 $kleq 500000$,直接爆算就好了。
d = 3 :
$$f(x) = frac 13 (e^x + e^{omega_3 x } + e^{omega_3^2 x })$$
注意到由于 d = 3 时, $kleq 1000$ ,所以和 $d = 2 $ 的情况差不多,暴力展开 2 层就好了。
具体怎么做直接看代码吧。懒得码式子了。
时间复杂度 $O(k^{d-1})$ 。
代码
#pragma GCC optimize("Ofast","inline") #include <bits/stdc++.h> #define clr(x) memset(x,0,sizeof (x)) #define For(i,a,b) for (int i=a;i<=b;i++) #define Fod(i,b,a) for (int i=b;i>=a;i--) #define pb push_back #define mp make_pair #define fi first #define se second #define _SEED_ ('C'+'L'+'Y'+'A'+'K'+'I'+'O'+'I') #define outval(x) printf(#x" = %d ",x) #define outvec(x) printf("vec "#x" = ");for (auto _v : x)printf("%d ",_v);puts("") #define outtag(x) puts("----------"#x"----------") #define outarr(a,L,R) printf(#a"[%d...%d] = ",L,R); For(_v2,L,R)printf("%d ",a[_v2]);puts(""); using namespace std; typedef long long LL; LL read(){ LL x=0,f=0; char ch=getchar(); while (!isdigit(ch)) f|=ch=='-',ch=getchar(); while (isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar(); return f?-x:x; } const int N=500005,mod=19491001,G=7; int Pow(int x,int y){ int ans=1; for (;y;y>>=1,x=(LL)x*x%mod) if (y&1) ans=(LL)ans*x%mod; return ans; } int n,k,d; int Fac[N],Inv[N]; void prework(){ int n=N-1; for (int i=Fac[0]=1;i<=n;i++) Fac[i]=(LL)Fac[i-1]*i%mod; Inv[n]=Pow(Fac[n],mod-2); Fod(i,n,1) Inv[i-1]=(LL)Inv[i]*i%mod; } int C(int n,int m){ if (m>n||m<0) return 0; return (LL)Fac[n]*Inv[m]%mod*Inv[n-m]%mod; } int main(){ prework(); n=read(),k=read(),d=read(); if (d==1){ cout<<Pow(k,n)<<endl; } else if (d==2){ int ans=0; For(i,0,k){ int c=(i-(k-i)+mod)%mod; ans=((LL)C(k,i)*Pow(c,n)+ans)%mod; } ans=(LL)ans*Pow(2,mod-1-k)%mod; cout<<ans<<endl; } else { int ans=0; int w0=1,w1=Pow(G,(mod-1)/3),w2=(LL)w1*w1%mod; For(i,0,k) For(j,0,k-i){ int c=((LL)w0*i+(LL)w1*j+(LL)w2*(k-i-j))%mod; ans=((LL)C(k,i)*C(k-i,j)%mod*Pow(c,n)+ans)%mod; } ans=(LL)ans*Pow(3,mod-1-k)%mod; cout<<ans<<endl; } return 0; }