NTT
NTT是一种跑得比FFT快的东西(?)。
元素的幂
考虑有限群G,(a in G)。元素的幂就是a的几次方。
使得(a^d=e)的最小正整数d称为a的阶,记作(d=ord(a))。
显然,a的幂生成的集合S是G的子群。因此,(a^{|G|}=e)。
原根
有个结论:(Z_n^*)存在原根(Leftrightarrow)(n=2,4,p^alpha,2p^alpha),p为奇素数。
tip:469762049,998244353和1004535809都有原根3。
设对于n,我们找到了原根g。设(g_n=g^{frac{p-1}{n}})。那么:
-
(g_n=g^{frac{p-1}{n}})
-
(g^n_n=g^{p-1}=1)
-
(g_{dn}^{dk}=(g^frac{p-1}{dn})^{dk}=(g^{frac{p-1}{n}})^k=g_n^k)(消去引理)
-
((g_n^k)^2=(g_n^{k+n/2})^2=(g^frac{p-1}{n/2})^k=g_{n/2}^k)(折半引理)
-
求和引理:(sum_{i=0}^{n-1}(g_n^k)^i =left{ egin{align} n,nmid k \ 0,n mid k end{align} ight.), 可以用类似fft的方法证。
然后我们就可以用原根愉快的做fft了。
inline void plus(int &x, int y){ x=1ll*x*y%mod; }
inline void plus(LL &x, LL y){ x*=y; x%=mod; }
inline void pro(int &x){ if (x<0) x+=mod; }
int fpow(LL a, LL x){
LL ans=1;
for (LL base=a; x; x>>=1, plus(base, base))
if (x&1) plus(ans, base);
return ans;
}
int inv(int x){ return fpow(x, mod-2); }
void fft(int *a, int l, int flag){
for (int i=0; i<l; ++i)
if (i<rev[i]) swap(a[i], a[rev[i]]);
LL gn, g, x, y;
for (int mid=1; mid<l; mid<<=1){ //区间半径
gn=fpow(G, (mod-1)/(mid<<1));
if (flag==-1) gn=inv(gn);
for (int j=0; j<l; j+=(mid<<1)){ g=1;
for (int k=j; k<j+mid; ++k, plus(g, gn)){
x=a[k]; y=g*a[k+mid]%mod;
a[k]=(x+y)%mod; a[k+mid]=(x-y+mod)%mod;
}
}
}
}
void ntt(int *a, int *b, int &la, int &lb){ //a=a*b
int l=1, bits=0; while (l<=la+lb) l<<=1, ++bits;
int linv=inv(l);
for (int i=1; i<l; ++i)
rev[i]=(rev[i>>1]>>1)|((i&1)<<(bits-1));
fft(a, l, 1); fft(b, l, 1);
for (int i=0; i<l; ++i) a[i]=1ll*a[i]*b[i]%mod;
fft(a, l, -1); la=la+lb;
for (int i=0; i<=la; ++i) a[i]=1ll*a[i]*linv%mod;
}
归纳一下写fft的思路。首先,应该明确我们要写的主要部分,是把系数表示转换成点值表示。先推式子,对于一个系数表示,把它拆成偶数和奇数的两个多项式。然后,把当前多项式写成两个子多项式的式子。利用单位根的性质,使得分叉个数为4。注意数组大小要开4n。
为了方便,把ntt给封装起来。
任意模数ntt
#include <cstdio>
#include <algorithm>
using namespace std;
typedef long long LL;
const int maxn=4e5+5, maxc=1e5+5; //maxn必须是4倍数组
const LL p1=998244353, p2=1004535809, p3=469762049, G=3;
int n, m, k, cntc, rev[maxn], mod, P;
inline void plus(int &x, int y){ x=1ll*x*y%mod; }
inline void plus(LL &x, LL y){ x*=y; x%=mod; }
inline void pro(int &x){ if (x<0) x+=mod; }
LL fmul(LL a, LL b, LL mod){
LL ans=0;
for (; b; b>>=1, a+=a, a%=mod)
if (b&1) ans+=a, ans%=mod;
return ans;
}
int fpow(LL a, LL x, LL mod){
LL ans=1;
for (LL base=a; x; x>>=1, (base*=base)%=mod)
if (x&1) (ans*=base)%=mod;
return ans;
}
int inv(int x){ return fpow(x, mod-2, mod); }
void fft(int *a, int l, int flag){
for (int i=0; i<l; ++i)
if (i<rev[i]) swap(a[i], a[rev[i]]);
LL gn, g, x, y;
for (int mid=1; mid<l; mid<<=1){ //区间半径
gn=fpow(G, (mod-1)/(mid<<1), mod);
if (flag==-1) gn=inv(gn);
for (int j=0; j<l; j+=(mid<<1)){ g=1;
for (int k=j; k<j+mid; ++k, plus(g, gn)){
x=a[k]; y=g*a[k+mid]%mod;
a[k]=(x+y)%mod; a[k+mid]=(x-y+mod)%mod;
}
}
}
}
void ntt(int *a, int *b, int &la, int &lb){ //a=a*b
int l=1, bits=0; while (l<=la+lb) l<<=1, ++bits;
int linv=inv(l);
for (int i=1; i<l; ++i)
rev[i]=(rev[i>>1]>>1)|((i&1)<<(bits-1));
fft(a, l, 1); fft(b, l, 1);
for (int i=0; i<l; ++i) a[i]=1ll*a[i]*b[i]%mod;
fft(a, l, -1);
for (int i=0; i<=la+lb; ++i) a[i]=1ll*a[i]*linv%mod;
}
int A[maxn], B[maxn], C[3][maxn], D[3][maxn];
int crt(LL c1, LL c2, LL c3){
static LL invp1=fpow(p1, p2-2, p2), invp2=fpow(p2, p1-2, p1);
static LL p0=p1*p2, invp0=fpow(p0%p3, p3-2, p3);
LL c0=fmul(p2*invp2, c1, p0)+fmul(p1*invp1, c2, p0); c0%=p0;
LL k=(c3+p3-c0%p3)*invp0%p3;
return (c0%mod+k*(p0%mod))%mod; //k*p0会爆!
}
int main(){
scanf("%d%d%d", &n, &m, &P);
for (int i=0; i<=n; ++i)
scanf("%d", &A[i]), C[0][i]=C[1][i]=C[2][i]=A[i];
for (int i=0; i<=m; ++i)
scanf("%d", &B[i]), D[0][i]=D[1][i]=D[2][i]=B[i];
mod=p1; ntt(C[0], D[0], n, m);
mod=p2; ntt(C[1], D[1], n, m);
mod=p3; ntt(C[2], D[2], n, m); mod=P;
for (int i=0; i<=n+m; ++i)
printf("%d ", crt(C[0][i], C[1][i], C[2][i]));
return 0;
}
两天后的PS:注意对于998244352,1004535808,469762048的分解。(998244352=2^{23}*x),(1004535808=2^{21}*x),(469762048=2^{26}*x)。这是因为三个质数本来就是(p=c*2^l+1)的形式。