很早就想学FFT那套理论,但抱着能咕一天是一天的态度咕到了今天
fft是干什么的?
求两个多项式卷积的,比如$g=a*b$($g_x=sum{a_i*b_{x-i}}$)
显然暴力乘是$O(n^2)$的,然而我们可以把他优化到$O(n;log;n)$
一般来将,多项式是用每一项的系数表示的,而还可以用点值来表示,比如一个多项式$a$有n项,我们可以让变量x取n个不同的值,然后用n个得出来的值表示这个多项式
由于是点值,两个多项式相乘时只要把对应的点值相乘即可,这是$O(n)$的
显然我们容易可以把两个表示法互相转换,比如$n^2$暴力和$n^3$高斯消元
这似乎比暴力还要不优越,所以我们要优化转化表示法的复杂度
我们发现我们可以随便选数x,只要x各不相同就行,然而当x在大部分取值时,都要暴力算$x^i$的值,这很不优越
所以我们的数学知识告诉我们数不只有实数,还有虚数啊懒得介绍虚数,我去拖一点东西过来
虚数大概就是可以表示在一个复平面上的东西,k次单位根$w_k$就是其k次方是1的东西$w_k$的i次方都在一个以原点为圆心,1为半径的圆上
由于复数运算的种种性质(幅角相加,长度相乘),这些东西是绕原点顺时针的,然后我们把i次单位根代入多项式来求值,这就是可以优化的了
然后我们就可以用分治来优化啦
把多项式分成奇数项和偶数项两部分然后分治,就像这样
愉快地盗了张图来 原文
由于要按照奇偶性分治,分治后的顺序会和原来不同,为了让实现更方便,可以把原下标的二进制翻转后当新下标(不会证)
那么我们就可以把系数表示变成点值表示了,这就是DFT
然后怎么把点值还原成系数呢,据说只要把点值除第一项的、部分翻转,然后跑DFT,再把结果除以n就好了(还是不会证)
大概就这样吧我感觉讲的海星
例题有这个(模板)多项式乘法(FFT)
AC代码
#include<bits/stdc++.h> using namespace std; #define int long long int n,m,i,j,len; int w[5000000],f[5000000],g[5000000]; int rev[5000000],bit,ha=998244353; inline void getrev(int n){for(i=0;i<(1<<n);i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));} int add(int x){return (x>=ha)?x-ha:x;} int jian(int x){return (x<0)?x+ha:x;} inline int qpow(int a,int b) { int ans=1; while(b){if(b%2)ans=ans*a%ha;a=a*a%ha,b/=2;}; return ans; } void fft(int *a,int n,int x) { if(x)reverse(a+1,a+n); for(int i=0;i<n;i++)if(rev[i]>i)swap(a[i],a[rev[i]]); for(int i=1;i<n;i<<=1) { w[0]=1; w[1]=qpow(3,(ha-1)/(i<<1)); for(int j=2;j<i;j++)w[j]=w[j-1]*w[1]%ha; for(int j=0;j<n;j+=(i<<1)) for(int k=j;k<j+i;k++) { int x=a[k],y=a[k+i]*w[k-j]%ha; a[k]=add(x+y),a[k+i]=jian(x-y); } } int ni=qpow(n,ha-2); if(x)for(int i=0;i<n;i++)a[i]=a[i]*ni%ha; } signed main() { scanf("%lld%lld",&n,&m); for(i=0;i<=n;i++)scanf("%lld",&f[i]); for(i=0;i<=m;i++)scanf("%lld",&g[i]); len=1; for(;len<m+n+1;len=len<<1)bit++;getrev(bit); fft(f,len,0);fft(g,len,0); for(i=0;i<len;i++)f[i]=f[i]*g[i]%ha; fft(f,len,1); for(i=0;i<=n+m;i++)printf("%lld ",f[i]); return 0; }
fft还有模意义下的版本NTT
只要把单位根变成模数的原根就行了,一般都是3,如998244353(1e9+7不是NTT模数)
如果要求没有原根的多项式乘法,可以用CRT把有原根的答案结合起来
再来道例题
AC代码
#include<bits/stdc++.h> using namespace std; #define int long long int n,m,i,j,len; int w[2100000],f[2100000],g[2100000]; int rev[2100000],bit,ha=998244353; inline void getrev(int n){for(i=0;i<(1<<n);i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));} int add(int x){return (x>=ha)?x-ha:x;} int jian(int x){return (x<0)?x+ha:x;} inline int qpow(int a,int b) { int ans=1; while(b){if(b%2)ans=ans*a%ha;a=a*a%ha,b/=2;}; return ans; } void fft(int *a,int n,int x) { if(x)reverse(a+1,a+n); for(int i=0;i<n;i++)if(rev[i]>i)swap(a[i],a[rev[i]]); for(int i=1;i<n;i<<=1) { w[0]=1; w[1]=qpow(3,(ha-1)/(i<<1)); for(int j=2;j<i;j++)w[j]=w[j-1]*w[1]%ha; for(int j=0;j<n;j+=(i<<1)) for(int k=j;k<j+i;k++) { int x=a[k],y=a[k+i]*w[k-j]%ha; a[k]=add(x+y),a[k+i]=jian(x-y); } } int ni=qpow(n,ha-2); if(x)for(int i=0;i<n;i++)a[i]=a[i]*ni%ha; } signed main() { scanf("%lld%lld",&n,&m); for(i=0;i<=n;i++)scanf("%lld",&f[i]); for(i=0;i<=m;i++)scanf("%lld",&g[i]); len=1; for(;len<m+n+1;len=len<<1)bit++;getrev(bit); fft(f,len,0);fft(g,len,0); for(i=0;i<len;i++)f[i]=f[i]*g[i]%ha; fft(f,len,1); for(i=0;i<=n+m;i++)printf("%lld ",f[i]); return 0; }