FFT:https://www.luogu.com.cn/problem/P3803
NTT:https://www.luogu.com.cn/problem/P4245
日常抄 oi-wiki
FFT
多项式乘法对多项式的系数向量进行卷积:(sum_{i=1}^{n+m}sum_{j=1}^i f(j)g(i-j))
DFT 可以认为是将多项式的系数表示法变成点值表示法,然后将点值分别相乘即可得到 (f imes g) 的点值表示,再用 IDFT 变换回去即可
而 FFT 便是快速实现 DFT 和 IDFT 的过程
定义 (x^n=1) 复数域的 (n) 个解为 (n) 次单位复根,表示为 (w_n)。有 (w_n=e^{frac{2pi i}{n}}=cos(frac{2pi}{n})+isin(frac{2pi}{n})),也就是把复平面下的单位圆均分成了 (n) 份
那么 (x^n=1) 的解就可以表示为 (w_n^k,k=0,1,2,cdots,n-1)
有性质:
- (w_n^n=1)
- (w_n^k=w_{2n}^{2k})
- (w_{2n}^k=-w_{2n}^{k+n})
DFT 在处理时将多项式按照次数的奇偶分治来求在每个 (w_n^k) 处的取值,比如有多项式:(f(x)=sum_{i=0}^7 a_ix^i)
那么按照奇偶分类:(f(x)=(a_0+a_2x^2+a_4x^4+a_6x^6)+x(a_1+a_3x^2+a_5x^4+a_7x^6))
然后分别定义 (g(x)=a_0+a_2x+a_4x^2+a_6x^3),h(x)=a_1+a_3x+a_5x^2+a_7x^3)
所以原式变成了 (g(x^2)+xcdot h(x^2))
如果分治的求出 (g(x),h(x)) 的点值,那么:
然后现在已经有了一个递归的做法,需要多项式的长度为 (n=2^k),所以只要对输入的数补齐到第一个大于等于它的 2 的幂次就行
然后带入 (w_n^0,w_n,w_n^2,cdots,w_n^{n-1}) 求值即可
更常用的是迭代版的
考虑分治的过程,当次数为 8 时,分治过程如下:
step 1: 0,1,2,3,4,5,6,7
step 2: 0,2,4,6; 1,3,5,7
step 3: 0,4; 2,6; 1,5; 3,7;
step 4: 0; 4; 2; 6; 1; 5; 3; 7;
然后发现并没有任何规律(?)
但如果写成二进制,就是:
000,001,010,011,100,101,110,111
变化成了
000,100,010,110,001,101,011,111
发现他恰好是把二进制的每一位反过来,这被称为 位逆序变换 或 蝴蝶变换
然后设 (rev_i) 为第 (i) 为要交换到的地方,那么讨论 (i) 的最后一个二进制位是 (0) 或 (1) 即可得知 (rev_i) 的最高二进制位,然后再根据 (rev_{i/2}) 的值进行递推:
for(reg int i=0;i<n;i++) rev[i]=rev[i>>1]>>1,rev[i]|=(i&1)?(n>>1):0;
然后提前交换就有了迭代版的 FFT
然后就是它的逆变换,将点值表示转化为系数表示
(f(x)) 组成的向量可以由系数组成的向量乘一个关于带入的值(也就是 (w_n^k))的矩阵得出,而 FFT 的过程实际上就是快速求了这样的一个矩阵和向量的乘积,具体来说就是:
现在知道等号左边的向量和乘的矩阵,要反推右边的系数向量,那么只要给中间的矩阵求逆即可
那么直接套用 (O(n^3)) 求逆的方法
其实这个矩阵的逆矩阵可以直接构造出来,每一位取倒数再除以 (n) 即可
然后只要给左边的 (f(x)) 向量用 (w_n^{-k}=cos(frac{2pi k}{n})-isin(frac{2pi k}{n})) 构成的矩阵再做一次 FFT 就是逆变换了
最终在结果上再除以 (n) 就行
const double PI=acos(-1);
struct Complex{
double a,b;
}a[N];
inline Complex operator + (Complex x,Complex y){return (Complex){x.a+y.a,x.b+y.b};}
inline Complex operator - (Complex x,Complex y){return (Complex){x.a-y.a,x.b-y.b};}
inline Complex operator * (Complex x,Complex y){return (Complex){x.a*y.a-x.b*y.b,x.a*y.b+x.b*y.a};}
int rev[N];
inline void init(reg int n){
for(reg int i=0;i<n;i++) rev[i]=rev[i>>1]>>1,rev[i]|=(i&1)?(n>>1):0;
}
inline void fft(int n,Complex *a,int type){
for(reg int i=0;i<n;i++)if(rev[i]<i) std::swap(a[i],a[rev[i]]);
Complex wn,w,o;
for(reg int h=1;h<n;h<<=1){
wn.a=cos(2*PI/(h<<1));wn.b=sin((type?2:-2)*PI/(h<<1));
for(reg int i=0;i<n;i+=h<<1){
w.a=1;w.b=0;
for(reg int j=i;j<i+h;j++,w=w*wn){
o=w*a[j+h];
a[j+h]=a[j]-o;a[j]=a[j]+o;
}
}
}
if(!type) for(reg int i=0;i<n;i++) a[i].a/=n,a[i].b/=n;
}
NTT
如果求得的系数需要取模就不能直接 FFT 了
所以考虑到对于质数 (p=qn+1,n=2^m),可以取其原根 (g),满足 (g^{qn}equiv 1 pmod p),然后就可以将 (g_n=g^q) 等价的看作 (w_n)
更进一步,(g_k=g^{(p-1)/k}) 可以等价的看作 (w_k)
发现它也满足之前说的 (w_n) 的性质:(g_n^nequiv 1pmod p,g_{2n}^{2k}equiv g_{n}^kpmod p,g_{2n}^kequiv g_{2n}^{n+k}pmod p)
所以就如果选取的质数有足够多的 (2) 的因数,就可以直接把 FFT 中的 (w_n) 换成 (g^{(p-1)/n}) 来计算
满足要求的质数有:998244353,1004535809,469762049
对于任意模数 NTT,可以直接选取以上三个质数分别做一次普通 NTT,然后再 CRT 合并,值域一般够用
放一个任意模数 NTT 的代码
#define G 3
#define mod1 998244353
#define mod2 1004535809
#define mod3 469762049
#define N 270006
int n,m,p;
inline int power(reg int a,reg int b,reg int mod){
reg int ans=1;
while(b){
if(b&1) ans=(long long)ans*a%mod;
b>>=1;a=(long long)a*a%mod;
}
return ans;
}
const int inv1=power(mod1,mod2-2,mod2),inv2=power((long long)mod1*mod2%mod3,mod3-2,mod3);
inline int Mod(reg int a,reg int mod){return a>=mod?a-mod:a;}
int rev[N];
int f1[N],f2[N],f3[N],g1[N],g2[N],g3[N];
inline void init(reg int n){
for(reg int i=0;i<n;i++) rev[i]=rev[i>>1]>>1,rev[i]|=(i&1)?(n>>1):0;
}
inline void ntt(reg int n,int *a,int type,int mod){
for(reg int i=0;i<n;i++)if(rev[i]<i) std::swap(a[i],a[rev[i]]);
for(reg int h=1;h<n;h<<=1){
int gn=power(G,(mod-1)/(h<<1),mod),g,o;
if(!type) gn=power(gn,mod-2,mod);
for(reg int i=0;i<n;i+=h<<1){
g=1;for(reg int j=i;j<i+h;j++,g=(long long)g*gn%mod){
o=(long long)g*a[j+h]%mod;
a[j+h]=Mod(a[j]-o+mod,mod);a[j]=Mod(a[j]+o,mod);
}
}
}
if(!type){
int inv=power(n,mod-2,mod);
for(reg int i=0;i<n;i++) a[i]=(long long)a[i]*inv%mod;
}
}
inline int get(long long a,long long b,long long c){
long long x=(b-a+mod2)%mod2*inv1%mod2*mod1+a;
return ((c-x%mod3+mod3)%mod3*inv2%mod3*mod1%p*mod2%p+x)%p;
}
int main(){
n=read();m=read();p=read();
for(reg int i=0;i<=n;i++) f1[i]=f2[i]=f3[i]=read()%p;
for(reg int i=0;i<=m;i++) g1[i]=g2[i]=g3[i]=read()%p;
int max=1;while(max<=n+m) max<<=1;
init(max);
ntt(max,f1,1,mod1);ntt(max,f2,1,mod2);ntt(max,f3,1,mod3);
ntt(max,g1,1,mod1);ntt(max,g2,1,mod2);ntt(max,g3,1,mod3);
for(reg int i=0;i<max;i++)
f1[i]=(long long)f1[i]*g1[i]%mod1,f2[i]=(long long)f2[i]*g2[i]%mod2,f3[i]=(long long)f3[i]*g3[i]%mod3;
ntt(max,f1,0,mod1);ntt(max,f2,0,mod2);ntt(max,f3,0,mod3);
for(reg int i=0;i<=n+m;i++) printf("%d ",get(f1[i],f2[i],f3[i]));
return 0;
}