• [FFT/NTT/MTT]总结


    最近重新学了下卷积,简单总结一下,不涉及细节内容:

    1、FFT

    朴素求法:$Coefficient-O(n^2)-CoefficientResult$

    FFT:$Coefficient-O(nlogn)-Dot-O(n)-DotResult-O(nlogn)-CoefficientResult$

    其中系数到点值的转化称为$DFT(离散傅里叶变换)$,而点值到系数的转为称为$IDFT(傅里叶逆变换)$

    原本朴素的直接带入$n$个值的$DFT$和直接使用拉格朗日插值公式的$IDFT$的复杂度仍为$O(n^2)$

    但$FFT$通过带入特定的值:单位根,使得两者都能迭代/分治得解决,将复杂度降到了$O(nlogn)$

    优化的技巧和注意事项:

    1、预处理$w[i]$

    2、求出最终数组从后往前迭代省去递归常数

    3、数组长度要先扩成2的倍数用于分治

    模板:

    #include <bits/stdc++.h>
    
    using namespace std;
    #define X first
    #define Y second
    #define pb push_back
    typedef double db;
    typedef long long ll;
    typedef pair<int,int> P;
    const int MAXN=3e6+10;
    struct Complex
    {
        db x,y;
        Complex(db a=0,db b=0){x=a;y=b;}
        Complex operator + (const Complex& rhs)
        {return Complex(x+rhs.x,y+rhs.y);}
        Complex operator - (const Complex& rhs)
        {return Complex(x-rhs.x,y-rhs.y);}
        Complex operator * (const Complex& rhs)
        {return Complex(x*rhs.x-y*rhs.y,x*rhs.y+y*rhs.x);}
    }a[MAXN],b[MAXN];
    int n,m,lmt=1,dgt,par[MAXN];
    
    void FFT(Complex *a,int flag)
    {
        for(int i=0;i<lmt;i++)
            if(i<par[i]) swap(a[i],a[par[i]]);
        
        for(int len=1;len<lmt;len<<=1)
        {
            Complex unit(cos(M_PI/len),flag*sin(M_PI/len));
            for(int st=0;st<lmt;st+=(len<<1))
            {
                Complex w(1,0);
                for(int k=st;k<st+len;k++,w=w*unit)
                {
                    Complex A=a[k],B=w*a[k+len];
                    a[k]=A+B;a[k+len]=A-B;
                }
            }
        }
        if(flag==-1)
            for(int i=0;i<=n+m;i++)
                a[i].x=floor(a[i].x/lmt+0.5);
    }
    
    int main()
    {    
        scanf("%d%d",&n,&m);
        for(int i=0;i<=n;i++) scanf("%lf",&a[i].x);
        for(int i=0;i<=m;i++) scanf("%lf",&b[i].x);
        while(lmt<=n+m) lmt<<=1,dgt++;
        for(int i=0;i<lmt;i++)
            par[i]=(par[i>>1]>>1)|((i&1)<<(dgt-1));    
        
        FFT(a,1);FFT(b,1);
        for(int i=0;i<lmt;i++) 
            a[i]=a[i]*b[i];
        FFT(a,-1);
        for(int i=0;i<=n+m;i++) 
            printf("%d ",(int)a[i].x);
        return 0;
    }
    FFT

    2、NTT

    单位根由于涉及了复数的运算,导致对精度要求高时会出错

    而$NTT$就能使得整个$FFT$都能在模意义下计算,从而满足精度要求

    考虑$FFT$引入单位根$w_n^k$是为了其什么性质来分治计算:

    1、$w_n^k$互不相同,保证点值表示的合法

    2、$w_{t*n}^{t*k}=w_n^k$且$w_n^{k+2/n}=-w_n^k$,使得计算可分治

    3、$sum_{i=0}^{n-1} {w_n^k}^i=n*[k==0]$,保证逆矩阵构造的正确性

    在模意义下引入质数$p=kn+1$,其原根$g$满足$g_t(tin [0,p-1])$互不相同

    这样令$p$的$k$次单位根为$g^{frac{p-1}{k}}$,易证上述$w_n^k$的性质其在模意义下均满足

    接下来考虑该怎样选择质数$p$

    为了能够分治时允许$k$每次乘2,$p-1$的质因数分解中要有很多的2

    令$p=r*2^k+1$,其能处理的数据规模为$[0,2^k]$,常用质数有:传送门

    这样,我们就在模意义下利用原根的性质找到了可做$FFT$的“单位根”

    由于没有了复数运算,$NTT$比$FFT$的常数也小了很多,一般是更好的选择

    模板:

    #include <bits/stdc++.h>
    
    using namespace std;
    #define X first
    #define Y second
    #define pb push_back
    typedef double db;
    typedef long long ll;
    typedef pair<int,int> P;
    const int MAXN=4e6+10,MOD=998244353;
    ll n,m,a[MAXN],b[MAXN],dgt,lmt=1,par[MAXN];
    
    ll quick_pow(ll a,ll b)
    {
        ll ret=1;
        for(;b;b>>=1,a=a*a%MOD)
            if(b&1) ret=ret*a%MOD;
        return ret;
    }
    void FFT(ll *a,int flag)
    {
        for(int i=0;i<lmt;i++)
            if(i<par[i]) swap(a[i],a[par[i]]);
        for(int len=1;len<lmt;len<<=1)
        {
            ll unit=quick_pow(3,(MOD-1)/(len<<1));
            if(flag==-1) unit=quick_pow(unit,MOD-2);
            for(int st=0;st<lmt;st+=(len<<1))
            {
                ll w=1;
                for(int k=st;k<st+len;k++,w=w*unit%MOD)
                {
                    ll A=a[k],B=w*a[k+len]%MOD;
                    a[k]=(A+B)%MOD;a[k+len]=(A-B+MOD)%MOD;
                }
            }
        }
    }
    
    int main()
    {
        scanf("%lld%lld",&n,&m);
        for(int i=0;i<=n;i++) scanf("%lld",&a[i]);
        for(int i=0;i<=m;i++) scanf("%lld",&b[i]);
        while(lmt<=n+m) lmt<<=1,dgt++;
        for(int i=0;i<lmt;i++)
            par[i]=(par[i>>1]>>1)|((i&1)<<(dgt-1));
        
        FFT(a,1);FFT(b,1);
        for(int i=0;i<lmt;i++)
            (a[i]*=b[i])%=MOD;
        FFT(a,-1);
        ll inv=quick_pow(lmt,MOD-2);
        for(int i=0;i<=n+m;i++)
            printf("%lld ",a[i]*inv%MOD);
        return 0;
    }
    NTT

    3、MTT

    如果答案需要取模且模数非质数该如何处理呢?

    常见背景为:多项式长度$1e5$,模数$1e9$非质数,此时$FFT$爆$longlong$,没法用$NTT$

    (1)三模数$NTT$

    根据上方的数据限制,可发现最终答案最多为$1e23$

    这样就能用多个乘积大于$1e23$的模数分别做$NTT$最后再用$CRT$合并答案即可

    一般常用:469762049,998244353,1004535809

    可如果直接用$CRT$合并会发现模数爆$longlong$还是不好处理

    此时就可以先合并前两个式子,得到

    $res=k(mod(p_1*p_2)),res=a_3(mod(p_3))$

    这样设$res=p_1*p_2*c+k$再带入二式就能得到$c=(a_3-k)*(p_1*p_2)^{-1}(mod(p_3))$

    这样类似$exCRT$的分步处理就避开了对$p_1*p_2*p_3$的取模

    但这样要进行9次$DFT/IDFT$,常数巨大无比

    模板:

    #include <bits/stdc++.h>
    
    using namespace std;
    #define X first
    #define Y second
    #define pb push_back
    typedef double db;
    typedef long long ll;
    typedef pair<int,int> P;
    const int MAXN=4e5+10;
    ll p[]={469762049,998244353,1004535809};
    int n,m,MOD,F[MAXN],G[MAXN],dgt,lmt=1;
    ll a[3][MAXN],b[MAXN],res[MAXN],par[MAXN];
    
    ll quickpow(ll a,ll b,ll MOD)
    {
        a%=MOD;ll ret=1;
        for(;b;b>>=1,a=a*a%MOD)
            if(b&1) ret=ret*a%MOD;
        return ret;
    }
    ll mul(ll a,ll b,ll MOD)
    {
        a=(a%MOD+MOD)%MOD;
        b=(b%MOD+MOD)%MOD;ll ret=0;
        for(;b;b>>=1,a=(a+a)%MOD)
            if(b&1) (ret+=a)%=MOD;
        return ret;
    }
    ll inv(ll a,ll MOD)
    {return quickpow(a,MOD-2,MOD);}
    void FFT(ll *a,int flag,ll MOD)
    {
        for(int i=0;i<lmt;i++)
            if(i<par[i]) swap(a[i],a[par[i]]);
        for(int len=1;len<lmt;len<<=1)
        {
            ll unit=quickpow(3,(MOD-1)/(len<<1),MOD);
            if(flag==-1) unit=inv(unit,MOD);
            for(int st=0;st<lmt;st+=(len<<1))
            {
                ll w=1;
                for(int k=st;k<st+len;k++,w=w*unit%MOD)
                {
                    ll A=a[k],B=w*a[k+len]%MOD;
                    a[k]=(A+B)%MOD;a[k+len]=(A-B+MOD)%MOD;
                }
            }
        }
        if(flag==-1)
        {
            ll INV=inv(lmt,MOD);
            for(int i=0;i<lmt;i++)
                a[i]=a[i]*INV%MOD;
        }
    }
    void solve(ll *a,ll *b,ll MOD)
    {
        for(int i=0;i<=n;i++) a[i]=F[i];
        for(int i=0;i<=m;i++) b[i]=G[i];
        for(int i=m+1;i<lmt;i++) b[i]=0;
        FFT(a,1,MOD);FFT(b,1,MOD);
        for(int i=0;i<lmt;i++) a[i]=a[i]*b[i]%MOD;
        FFT(a,-1,MOD);
    }
    
    int main()
    {
        scanf("%d%d%d",&n,&m,&MOD);
        for(int i=0;i<=n;i++) scanf("%d",&F[i]);
        for(int i=0;i<=m;i++) scanf("%d",&G[i]);
        while(lmt<=n+m) lmt<<=1,dgt++;
        for(int i=0;i<lmt;i++) par[i]=(par[i>>1]>>1)|((i&1)<<(dgt-1));
        
        for(int i=0;i<3;i++) solve(a[i],b,p[i]);
        for(int i=0;i<=n+m;i++)
        {
            ll M=p[0]*p[1];
            ll A=(mul(a[0][i]*p[1],inv(p[1],p[0]),M)+
                 mul(a[1][i]*p[0],inv(p[0],p[1]),M))%M;
            ll K=mul(a[2][i]-A,inv(M,p[2]),p[2]);
            res[i]=(mul(K,M,MOD)+A%MOD)%MOD;
        }
        for(int i=0;i<=n+m;i++)
            printf("%lld ",res[i]);
        return 0;
    }
    三模数NTT

    (2)拆系数$FFT$

     不能用$FFT$仅仅因为最后答案会爆$longlong$,那么可以将原数拆分后分别计算

    $A_i=a_i* sqrt{P}+b_i,B_i=c_i* sqrt{P}+d_i$

    此时$A*B=P*(a*c)+sqrt{P}*(a*d+b*c)+(b*d)$,每部分最大值为$1e14$,分别$DFT/IDFT$

    这样要做7次$DFT/IDFT$,效率未显著提升

    $myy$在论文里提到过对FFT的优化:

    设$P_j=A_j+i*B_j,Q_j=A_j-i*B_j$,使得$DFT$前虚部不再为空

    可推出$DFT$后的$DP,DQ$数组如下结论:

    $DP_k=sum_{j=0}^{lmt-1} (A_j+i*B_j)*w_{lmt}^{j*k},DQ_k=conj(DP_{lmt-k})$

    这样就能用1次对$P$的$DFT$算出$P,Q,A,B$的$DFT$,从而将上面的4次$DFT$化为2次

    由于$IDFT$就能看成$DFT$的逆过程

    因此可以合并算出$IDFT(DFT[a]*DFT[c]+i*DFT[b]*DFT[d])$,从而将$IDFT$也化为2次

    这样的常数经测试是三模数$NTT$的1/7左右

    模板:

    #include <bits/stdc++.h>
    
    using namespace std;
    #define X first
    #define Y second
    #define pb push_back
    typedef double db;
    typedef long long ll;
    typedef pair<int,int> P;
    const int MAXN=1e6+10;
    struct Complex
    {
        db x,y;
        Complex(db a=0,db b=0){x=a;y=b;}
        Complex operator +(const Complex& rhs)
        {return Complex(x+rhs.x,y+rhs.y);}
        Complex operator -(const Complex& rhs)
        {return Complex(x-rhs.x,y-rhs.y);}
        Complex operator *(const Complex& rhs)
        {return Complex(x*rhs.x-y*rhs.y,x*rhs.y+y*rhs.x);}
    }a[MAXN],b[MAXN],w[MAXN],t1[MAXN],t2[MAXN],t3[MAXN];
    int n,m,MOD,lmt=1,dgt,par[MAXN];ll x,res[MAXN];
    
    void FFT(Complex *a,int flag)
    {
        for(int i=0;i<lmt;i++)
            if(i<par[i]) swap(a[i],a[par[i]]);
        for(int len=1;len<lmt;len<<=1)
            for(int st=0;st<lmt;st+=(len<<1))
            {
                int cur=0;
                for(int k=st;k<st+len;k++)
                {
                    Complex A=a[k],B=w[cur]*a[k+len];
                    a[k]=A+B;a[k+len]=A-B;
                    //预处理的写法 
                    cur=(cur+flag*lmt/(len<<1)+lmt)&(lmt-1);
                }
            }
        if(flag==-1)
            for(int i=0;i<lmt;i++)
                a[i].x=floor(a[i].x/lmt+0.5);
    }
    void solve()
    {
        FFT(a,1);FFT(b,1);
        for(int i=0;i<lmt;i++)
        {
            Complex d1,d2,d3,d4;
            int j=(lmt-i)&(lmt-1);
            d1=(a[i]+Complex(a[j].x,-a[j].y))*Complex(0.5,0);
            d2=(a[i]-Complex(a[j].x,-a[j].y))*Complex(0,-0.5);
            d3=(b[i]+Complex(b[j].x,-b[j].y))*Complex(0.5,0);
            d4=(b[i]-Complex(b[j].x,-b[j].y))*Complex(0,-0.5);
            //必须先用临时变量存,因为后面还要用 
            t1[i]=d1*d3;t2[i]=d1*d4+d2*d3;t3[i]=d2*d4;
        }
        for(int i=0;i<lmt;i++)
            //充分利用虚部空间(可看成逆过程) 
            b[i]=t2[i],a[i]=t1[i]+t3[i]*Complex(0,1);
        FFT(a,-1);FFT(b,-1);
        for(int i=0;i<lmt;i++)
        {
            ll k1=(ll)a[i].x%MOD,k2=(ll)b[i].x%MOD;
            ll k3=(ll)floor(a[i].y/lmt+0.5)%MOD;
            res[i]=((k3<<30)%MOD+(k2<<15)%MOD+k1)%MOD;
        }
    }
    
    int main()
    {
        scanf("%d%d%d",&n,&m,&MOD);
        for(int i=0;i<=n;i++)
            scanf("%lld",&x),a[i]=Complex(x&32767,x>>15);
        for(int i=0;i<=m;i++)
            scanf("%lld",&x),b[i]=Complex(x&32767,x>>15);
        while(lmt<=n+m) lmt<<=1,dgt++;
        for(int i=0;i<lmt;i++)
            par[i]=(par[i>>1]>>1)|((i&1)<<(dgt-1));
        for(int i=0;i<lmt;i++)
            w[i]=Complex(cos(2*M_PI*i/lmt),sin(2*M_PI*i/lmt));
        
        solve();
        for(int i=0;i<=n+m;i++)
            printf("%lld ",res[i]);
        return 0;
    }
    拆系数FFT
  • 相关阅读:
    简明Vim练级攻略
    linux之cat命令
    linux之cat,more,less,head,tail
    linux之touch命令修改文件的时间戳
    linux 之创建文件命令
    python开发_function annotations
    python开发_python中的range()函数
    python开发_python中的module
    python开发_python中的函数定义
    python开发_python关键字
  • 原文地址:https://www.cnblogs.com/newera/p/10076871.html
Copyright © 2020-2023  润新知