• FFT NTT 学习笔记


    显然因为我不会数学,所以这篇文章会非常“ 感性 ”。

    题目

    将两个多项式乘起来,即求 (f*g=h) 。多项式的项数 (nle10^5)

    FFT

    前置知识

    复数

    复数是指形如 (x+yi) 的数,高中会教。

    它的四则运算法则是这样的:(令 (p,q) 为两个复数)

    [ppm q=(x_ppm x_q)+(y_ppm y_q)i\ p imes q=(x_px_q-y_py_q)+(x_py_q+x_qy_p)i ]

    除法用不着。

    所以代码是这样的:(建议不要用 STL 的复数,常数巨大)

    struct mle{lod x,y;}a[n7],b[n7];
    mle operator + (mle p,mle q){return (mle){p.x+q.x,p.y+q.y};}
    mle operator - (mle p,mle q){return (mle){p.x-q.x,p.y-q.y};}
    mle operator * (mle p,mle q){return (mle){p.x*q.x-p.y*q.y,p.x*q.y+p.y*q.x};}
    

    多项式运算

    略。

    点值表达法

    选出 ((2n-1)) 个互不相同的横坐标 (x_i) ,代入 (f)(g) 中,得到很多个 (fy_i,gy_i),而 ((x_i,fy_i)) 就是 (f) 的点值表达式, ((x_i,gy_i)) 就是 (g) 的点值表达式。神奇的事实是, ((x_i,fy_i imes gy_i)) 就是 (h) 的点值表达式!

    所以 FFT 的思想就是,将 (f)(g) 转换成点值表达法,然后相乘得到 (h),最后再化为系数表示法(普通多项式的表示)

    其中,转化为点值表达法的步骤叫做 DFT,化回来的步骤叫做 IDFT

    单位根

    有一个神奇的东西叫做单位根(复数),满足 (w^n=1)(w) 被称作 (n) 次单位根。(事实上应该是 (omega),但是写起来太麻烦了就用 (w) 了)

    经过推导,如果将所有的单位根排列,第 (k)(n) 次单位根 (Large w_k=e^{ifrac{2kpi}{n}})

    (n) 是正偶数,且 (m)(n) 的一半,那么有 ((w_n^k)^2=w_m^k) 以及 (w_n^{m+k}=-w-n^k)。这两个等式就是法术,下文会用到。

    算法

    DFT(转成点值表示)

    首先为了方便,我们把多项式变为 (n=(2^k-1)) 次多项式(不足的补系数 (0)),而且它大于原本的 (f)(g) 的项数之和。这样原来的 ((2n-1)) 就不大于的 (n) ,方便统计。

    然后我们想知道点值表达,所以我们需要代入一些 (x),并求出 (f(x)) 的值。

    我们选择把魔幻的单位根(w_n^0,w_n^1...w_n^{n-1}) 代入。

    好,我们怎么求 (f(x))?先变变形式:

    [f(x)=a_0x^0+a_1x^1+a_2x^2+...+a_{n-1}x^{n-1}\ f(x)=(a_0x^0+a_2x^2+a_4x^4+...)+(a_1x^1+a_3x^3+a_5x^5+...)\ f(x)=(a_0x^0+a_2x^2+a_4x^4+...)+x(a_1x^0+a_3x^2+a_5x^4+...)\ f(x)=f_0(x^2)+xf_1(x^2) ]

    其中 (f,f_0,f_1) 并不是一样的,注意,(f_0) 的系数依次为 (a_0,a_2,a_4...)(f_1)(a_1,a_3,a_5...)

    运用着前文提到的法术,仍然设 (m=frac{n}{2}),对于 (k<m)

    [代入:f(w_n^k)=f_0((w_n^k)^2)+(w_n^k)f_1((w_n^k)^2)\ 等式一:f(w_n^k)=f_0(w_m^k)+w_n^kf_1(w_m^k) ]

    那么对于 (kge m) 呢?我们可以说它是 ((k+m))(k<m) 。继续用法术

    [代入:f(w_n^{m+k})=f_0((w_n^{m+k})^2)+(w_n^{m+k})f_1((w_n^{m+k})^2)\ 等式二:f(w_n^{m+k})=f_0((w_n^k)^2)+-w_n^kf_1((w_n^k)^2)\ 等式一:f(w_n^{m+k})=f_0(w_m^k)-w_n^kf_1(w_m^k) ]

    哇塞,一个是 (f_0+f_1),另一个是 (f_0-f_1)

    于是我们只要求出 (w_n^0sim w_n^m) 就可以知道剩下的了。而 (f_0,f_1) 可以递归求。

    IDFT(化回来)

    不会。但是代码和 DFT 是基本一样的。

    实现

    首先是直观但是常数大的递归版

    void FFT(mle *c,int len,bool sys){
        //part 1
    	if(len==1)return;
    	mle zuo[(len>>1)+1],you[(len>>1)+1];
    	for(int i=0;i<=len;i=i+2)zuo[i>>1]=c[i],you[i>>1]=c[i+1];
    	FFT(zuo,len>>1,sys),FFT(you,len>>1,sys);
    	
        //part 2
    	lod tnp=2.0*pie/len;int wal=len>>1;
    	mle ori=(mle){cos(tnp),(sys?1:-1)*sin(tnp)},z=(mle){1,0};
    	
        //part 3
    	rep(i,0,wal-1){
    		c[i]=zuo[i]+z*you[i];
    		c[i+wal]=zuo[i]-z*you[i];
    		z=z*ori;
    	}
    }
    

    PART:

    1. 把系数分为两个部分,然后依次递归。

    2. 计算。其中 (w_n^0=1,w_n^k=w_n^{k-1}*ori)

      其中 DFT(ori=cos{frac{2pi}{n}+sinfrac{2pi}{n}})IDFT(ori=cos{frac{2pi}{n}-sinfrac{2pi}{n}})。以及 (wal) 就是 (frac{n}{2})

    3. 做前文的事情。运用 (w_n^k=w_n^{k-1}*ori)

    注意,DFT 的时候把原本装系数的数组变成了现在装 (f(w)) 的数组。

    于是你就轻松有了 66分。毫无疑问,常数太大了!

    优化

    对于第 (x) 个系数 (a_x) ,它的路径是怎么样的?

    0 1 2 3 4 5 6 7
    0 2 4 6,1 3 5 7
    0 4,2 6,1 5,3 7
    0,4,2,6,1,5,3,7
    

    你会发现,

    如果把 0,4,2,6,1,5,3,7

    每一个数都转为二进制 000, 100, 010, 110, 001, 101, 011, 111

    再每一个二进制反过来 000, 001, 010, 011, 100, 101, 110, 111

    最后化为十进制 0,1,2,3,4,5,6,7。哦豁!

    所以可以快速求得递归的底层是怎么样的,然后我们模拟递归,枚举长度(1,2,4,8……),然后把一段长度的合并。

    但是又怎么样求二进制反转呢?

    (color{red}Huge 待填!)

    顺便提一句小优化,因为复数乘法常数大,所以一般在 (f=f_0pm f_1)是这样写:

    	rep(i,0,wal-1){
    		mle mul=z*you[i];
    		c[i]=zuo[i]+mul;
    		c[i+wal]=zuo[i]-mul;
    		z=z*ori;
    	}
    

    最后代码:

    也是待填



    FFT代码

    #include<bits/stdc++.h>
    #define rep(i,x,y) for(int i=x;i<=y;++i)
    #define lod double
    using namespace std;
    const int n7=3012345;
    const lod pie=acos(-1);
    int n,m,rv[n7];
    
    struct mle{lod x,y;}a[n7],b[n7];
    mle operator + (mle p,mle q){return (mle){p.x+q.x,p.y+q.y};}
    mle operator - (mle p,mle q){return (mle){p.x-q.x,p.y-q.y};}
    mle operator * (mle p,mle q){return (mle){p.x*q.x-p.y*q.y,p.x*q.y+p.y*q.x};}
    
    int rd(){
    	int shu=0;char ch=getchar();
    	while(!isdigit(ch))ch=getchar();
    	while(isdigit(ch))shu=(shu<<1)+(shu<<3)+ch-'0',ch=getchar();
    	return shu;
    }
    
    void FFT(mle *c,bool sys){
    	rep(i,0,n-1)if(i<rv[i])swap(c[i],c[ rv[i] ]);
    	for(int len=2;len<=n;len<<=1){
    		mle ori=(mle){cos(2*pie/len),(sys?1:-1)*sin(2*pie/len)};
    		int le=(len>>1);
    		for(int i=0;i<n;i+=len){
    			mle z=(mle){1,0};
    			rep(j,i,i+le-1){
    				mle tmp=z*c[le+j];
    				c[le+j]=c[j]-tmp;
    				c[j]=c[j]+tmp;
    				z=z*ori;
    			}
    		}
    	}
    }
    
    int main(){
    	n=rd(),m=rd();
    	rep(i,0,n)a[i].x=rd();
    	rep(i,0,m)b[i].x=rd();
    	m=m+n,n=1;
    	while(n<=m)n=n<<1;
    	rep(i,0,n-1){
    		rv[i]=(rv[i>>1]>>1);
    		if(i&1)rv[i]=rv[i]|(n>>1);
    	}
    	FFT(a,1),FFT(b,1);
    	rep(i,0,n)a[i]=a[i]*b[i];
    	FFT(a,0);
    	rep(i,0,m)printf("%d ",(int)(a[i].x/n+0.5));
    	return 0;
    }
    
    

    NTT代码

    #include<bits/stdc++.h>
    #define rep(i,x,y) for(int i=x;i<=y;++i)
    #define lon long long
    using namespace std;
    const int n7=3012345;const lon mo=998244353;
    int n,m,rv[n7];lon a[n7],b[n7];
    
    int rd(){
    	int shu=0;char ch=getchar();
    	while(!isdigit(ch))ch=getchar();
    	while(isdigit(ch))shu=(shu<<1)+(shu<<3)+ch-'0',ch=getchar();
    	return shu;
    }
    
    lon Dpow(lon p,lon q){
    	lon tot=1;
    	while(q){
    		if(q&1)tot=tot*p%mo;
    		p=p*p%mo,q=q>>1;
    	}
    	return tot;
    }
    
    void NTT(lon *c,bool sys){
    	rep(i,0,n-1)if(i<rv[i])swap(c[i],c[ rv[i] ]);
    	for(int len=2;len<=n;len<<=1){
    		lon ori=Dpow(sys?3:332748118,(mo-1)/len);
    		int le=(len>>1);
    		for(int i=0;i<n;i+=len){
    			lon z=1;
    			rep(j,i,i+le-1){
    				lon tmp=z*c[le+j]%mo;
    				c[le+j]=(c[j]-tmp+mo)%mo;
    				c[j]=(c[j]+tmp)%mo;
    				z=z*ori%mo;
    			}
    		}
    	}
    }
    
    int main(){
    	n=rd(),m=rd();
    	rep(i,0,n)a[i]=rd();
    	rep(i,0,m)b[i]=rd();
    	m=m+n,n=1;
    	while(n<=m)n=n<<1;
    	rep(i,0,n-1){
    		rv[i]=(rv[i>>1]>>1);
    		if(i&1)rv[i]=rv[i]|(n>>1);
    	}
    	NTT(a,1),NTT(b,1);
    	rep(i,0,n)a[i]=a[i]*b[i]%mo;
    	NTT(a,0);
    	lon inv=Dpow(n,mo-2);
    	rep(i,0,n)a[i]=a[i]*inv%mo;
    	rep(i,0,m)printf("%lld ",a[i]);
    	return 0;
    }
    
    
  • 相关阅读:
    HTTP请求下载文件格式
    MT7621 加 openWRT 用HTTP和远程服务器通信
    MT7621加 OPENWRT 移植MQTT(paho.mqtt.c) 进行数据的收发
    MT7621安装的openwrt出现无法删除文件的问题
    GAI_LIB = -lanl
    error: expected declaration specifiers or '...' before numeric constant void free(void *);
    environment variable 'STAGING_DIR' not defined
    ubuntu安装 make4.2
    gcc在root权限下查不到版本
    【原创】大叔经验分享(113)markdown语法
  • 原文地址:https://www.cnblogs.com/BlankAo/p/14333216.html
Copyright © 2020-2023  润新知