• 【数学】快速数论变换(NTT)


    快速数论变换(NTT)

    这东西之前就想学了,一直没有动手 orz,现在补一下。

    学这东西我感觉并没有很多新知识,学之前掌握 FFT 就好了。

    FFT 可以在这里看看:https://www.cnblogs.com/Tenshi/p/15434004.html

    NTT,是用来解决多项式乘法取模问题的,因为 FFT 可能在精度上不够,而且常数较大,因此使用数论的手段将 FFT 进行改造得到 NTT。

    原理

    我们先假设模数 \(P\)​ 具有原根 \(rt\)​,这意味着 \(rt^{\varphi(P)}=1\)​,如果模数更特殊一点,假设模数为具有原根的质数(比如 \(998244353\))。那么则有 \(rt^{P-1} \equiv 1 \pmod P,~ rt^{P} \equiv rt \pmod P\)​​。

    而在 FFT 中,我们有 \((\cos\theta + i\sin\theta)^{2\pi/ \theta} = 1\),形式上和 \(rt^{P-1} \equiv 1 \pmod P\) 非常接近,我们再对模数进行限制:模数需要能表述为 \(k\times 2^x + 1\)​ 的形式。(因为这样才能保证在分治的过程中 \(P-1\) 能被 \(2\) 的幂次整除)。

    因此在取模意义下我们可以类似于 FFT 写出 NTT 的代码:

    void NTT(ll *a, int type, int mod){
    	for(int i=0; i<tot; i++){
    		a[i]%=mod;
    		if(i<rev[i]) swap(a[i], a[rev[i]]);
    	}
    	
    	for(int mid=1; mid<tot; mid<<=1){
    		ll w1=fpow(rt, (type==1? (mod-1)/(mid<<1): mod-1-(mod-1)/(mid<<1)), mod);
    		for(int i=0; i<tot; i+=mid*2){
    			ll wk=1;
    			for(int j=0; j<mid; j++, wk=wk*w1%mod){
    				auto x=a[i+j], y=wk*a[i+j+mid]%mod;
    				a[i+j]=(x+y)%mod, a[i+j+mid]=(x-y+mod)%mod;
    			}
    		}
    	}
    	
    	if(type==-1){
    		for(int i=0; i<tot; i++) a[i]=a[i]*inv(tot, mod)%mod;
    	}
    }
    

    拓展

    由上所述,我们在 NTT 过程中对模数进行了一些限制,自然不能够直接推广来解决任意模数 NTT 的问题,怎么办呢?

    结合模板题来说:

    https://www.luogu.com.cn/problem/P4245

    记题目所给的模数为 \(P\)

    我们可以选择三个模数 \(m_1,m_2,m_3\)​ 满足 \(m_1m_2m_3 > nP^2\)​。

    可以取为 \(m_1=998244353,~ m_2=1004535809,~ m_3=469762049\)。​

    我们先用这三个模数分别做一次 NTT,然后用 CRT(中国剩余定理)将这三个结果合并即可。

    直接合并会爆 long long,那怎么合并呢?具体来说:

    假设所求结果为 \(ans\),同时我们约定 \(inv(x,y)\) 代表 \(x\)\(y\)​ 的逆元。

    \[ans\equiv c_1 \pmod {m_1} \\ ans\equiv c_2 \pmod {m_2} \\ ans\equiv c_3 \pmod {m_3} \\ \]

    先对前两个进行合并:

    \(ans \equiv c_1\times m_2\times inv(m_1, m_2) + c_2\times m_1\times inv(m_2, m_1) \pmod {m_1m_2}\)

    \(M=m_1m_2, ~ C=c_1\times m_2\times inv(m_1, m_2) + c_2\times m_1\times inv(m_2, m_1)\)

    那么我们设 \(ans = xM + C = ym_3 + c_3\)​​。

    则有 \(x \equiv (c_3-C)\times inv(M, m_3) \pmod {m_3}\)​。

    \(t = (c_3-C)\times inv(M, m_3)\)​​,\(t\) 显然可以直接算出来,我们进一步设 \(x = km_3 + t\)

    \(ans = km_1m_2m_3 + tM + C\)​​。

    因为 \(ans < m_1 m_2 m_3\),因此 \(ans=tM+C\)

    最后对 \(P\) 取模即可。​

    实现:

    // Problem: P4245 【模板】任意模数多项式乘法
    // Contest: Luogu
    // URL: https://www.luogu.com.cn/problem/P4245
    // Memory Limit: 500 MB
    // Time Limit: 2000 ms
    // 
    // Powered by CP Editor (https://cpeditor.org)
    
    #include<bits/stdc++.h>
    using namespace std;
    
    #define debug(x) cerr << #x << ": " << (x) << endl
    #define rep(i,a,b) for(int i=(a);i<=(b);i++)
    #define dwn(i,a,b) for(int i=(a);i>=(b);i--)
    
    using pii = pair<int, int>;
    using ll = long long;
    
    inline void read(int &x){
        int s=0; x=1;
        char ch=getchar();
        while(ch<'0' || ch>'9') {if(ch=='-')x=-1;ch=getchar();}
        while(ch>='0' && ch<='9') s=(s<<3)+(s<<1)+ch-'0',ch=getchar();
        x*=s;
    }
    
    const int N=3e5+5;
    const ll m1=998244353, m2=1004535809, m3=469762049, M=m1*m2, rt=3;
    
    int n, m, P;
    ll a[3][N], b[3][N], ans[N];
    
    int rev[N], tot=1, bit;
    
    ll fpow(ll x, int p, ll mod){
    	int res=1;
    	for(; p; p>>=1, x=x*x%mod) if(p&1) res=res*x%mod;
    	return res;
    }
    
    ll inv(ll x, ll mod){
    	return fpow(x, mod-2, mod);
    }
    
    ll mul(ll x, int p, ll mod){
    	ll res=0;
    	for(; p; p>>=1, x=(x+x)%mod) if(p&1) res=(res+x)%mod;
    	return res;
    }
    
    void NTT(ll *a, int type, int mod){
    	for(int i=0; i<tot; i++){
    		a[i]%=mod;
    		if(i<rev[i]) swap(a[i], a[rev[i]]);
    	}
    	
    	for(int mid=1; mid<tot; mid<<=1){
    		ll w1=fpow(rt, (type==1? (mod-1)/(mid<<1): mod-1-(mod-1)/(mid<<1)), mod);
    		for(int i=0; i<tot; i+=mid*2){
    			ll wk=1;
    			for(int j=0; j<mid; j++, wk=wk*w1%mod){
    				auto x=a[i+j], y=wk*a[i+j+mid]%mod;
    				a[i+j]=(x+y)%mod, a[i+j+mid]=(x-y+mod)%mod;
    			}
    		}
    	}
    	
    	if(type==-1){
    		for(int i=0; i<tot; i++) a[i]=a[i]*inv(tot, mod)%mod;
    	}
    }
    
    void CRT(){
    	for(int i=0; i<tot; i++){
    		ll res=0;
    		(res+=mul(a[0][i]*m2%M, inv(m2, m1), M))%=M;
    		(res+=mul(a[1][i]*m1%M, inv(m1, m2), M))%=M;
    		a[1][i]=res;
    	}
    	for(int i=0; i<tot; i++){
    		ll res=(a[2][i]-a[1][i]%m3+m3)%m3*inv(M%m3, m3)%m3;
    		ans[i]=(M%P*res%P+a[1][i]%P)%P;
    	}
    }
    
    void solve(int k, int mod){
    	NTT(a[k], 1, mod), NTT(b[k], 1, mod);
    	for(int i=0; i<tot; i++) a[k][i]=a[k][i]*b[k][i]%mod;
    	NTT(a[k], -1, mod);
    }
    
    int main(){
    	cin>>n>>m>>P;
    	rep(i,0,n){
    		int t; read(t);
    		rep(j,0,2) a[j][i]=t%P;
    	}
    	rep(i,0,m){
    		int t; read(t);
    		rep(j,0,2) b[j][i]=t%P;
    	}
    	
    	while(tot<=n+m) bit++, tot<<=1;
    	for(int i=0; i<tot; i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
    	
    	solve(0, m1), solve(1, m2), solve(2, m3);
    	CRT();
    	
    	rep(i,0,n+m) cout<<ans[i]<<' ';
    	cout<<endl;
    	
    	return 0;
    }
    
  • 相关阅读:
    使用Shell脚本查找程序对应的进程ID,并杀死进程
    转,mysql快速保存插入大量数据一些方法总结
    L2TP/IPSec一键安装脚本
    全文搜索引擎 Elasticsearch 入门教程
    vmware设置扩大硬盘后如何在linux内容扩容
    Java序列化说明
    GIT常用命令
    java中的CAS
    Class.forName()用法详解
    Java用pdfbox或icepdf转换PDF为图片时,中文乱码问题
  • 原文地址:https://www.cnblogs.com/Tenshi/p/15642604.html
Copyright © 2020-2023  润新知