• 洛谷.4245.[模板]任意模数NTT(MTT/三模数NTT)


    题目链接


    三模数(NTT)

    就是多模数(NTT)最后(CRT)一下...下面两篇讲的都挺明白的。
    https://blog.csdn.net/kscla/article/details/79547242
    https://blog.csdn.net/zhouyuheng2003/article/details/85561887

    模数不是(NTT)模数,考虑用多个(NTT)模数分别卷积,最后(CRT)合并(由中国剩余定理,同余方程组在模(M=prod m_i)的情况下解是唯一的)。
    卷积后(a_i)的最大值是((10^9)^2 imes10^5=10^{23}),所以需要选几个乘积(>10^{23})的质数,比如(m_1=998244353=2^{23}*119+1,quad m_2=1004535809=2^{21}*479+1,quad m_3=469762049=2^{26}*7+1)。它们的原根都是(3)
    最后我们能求出三个同余方程:$$xequiv c_1 (mathbb{mod} m_1)xequiv c_2 (mathbb{mod} m_2)xequiv c_3 (mathbb{mod} m_3)$$

    我们可以用(CRT)合并两个,就有(xequiv c_4 (mathbb{mod} m_1m_2))
    (x=am_1m_2+c_4equiv c_3 (mathbb{mod} m_3)),就可以算出(aequivfrac{c_3-c_4}{m_1m_2} (mathbb{mod} m_3)),有了(a)代回去就算出(x)啦(虽然是模(m_3)意义下求的(a),但是设(a=km_3+b),代回到(x)里模(m_1m_2m_3)就把(km_3)那项消掉啦,只剩下(b))(当然前两个式子也可以这么合并)。
    这样一共需要(9)(DFT),常数巨大。


    拆系数(FFT)((MTT))

    (FFT)的问题在于精度,我们可以压缩值域。
    (m=lceilsqrt p ceil, A_i=a imes m+b, B_i=c imes m+d),那么(h_{i+j}=sum(a_im+b_i)(c_jm+d_j)=sum a_ic_jm^2+(a_id_j+b_ic_j)m+b_id_j),分别求出四项(中间两项放一块算),就只需要(7)(DFT)
    为了方便可以令(m=2^{15}=32768)
    这样(FFT)后的最大结果是(10^5 imes2^{30}=10^{14}),注意取模。

    注意std::sin的精度比sin高!因为值域依旧很大所以必须要注意这个,还要开long double
    还有个挺重要的优化是预处理单位根,预处理(lim)次单位根的(0sim lim-1)次方,用(omega_i^k)时就是(omega_{lim}^{k imesfrac{lim}{i}})(怎么都预处理的(2lim)次单位根啊...)。
    注意预处理的时候不要像(FFT)里那样每次w=w*Wn,而是对每个单位根直接用欧拉公式算,这样精度误差会小非常多,就可以把long double替换成double了...(然后就可以快很多很多)
    (嗯...误差问题都出在求单位根上了...)

    现在(7)(DFT)已经挺快了...还可以优化,可以看毛啸的论文,或者这里以及这里我咕了。


    三模数(NTT)

    //3786ms	8.95MB
    #include <cstdio>
    #include <cctype>
    #include <cstring>
    #include <algorithm>
    #define G 3
    #define MOD(x,mod) x>=mod&&(x-=mod)
    #define ADD(x,v,mod) (x+=v)>=mod&&(x-=mod)
    #define gc() getchar()
    #define MAXIN 500000
    //#define gc() (SS==TT&&(TT=(SS=IN)+fread(IN,1,MAXIN,stdin),SS==TT)?EOF:*SS++)
    typedef long long LL;
    const int N=(1<<18)+3,Mod[]={998244353,469762049,1004535809};
    
    int tA[N],tB[N],A[N],B[N],Ans[3][N],rev[N];
    char IN[MAXIN],*SS=IN,*TT=IN;
    
    inline int read()
    {
    	int now=0;register char c=gc();
    	for(;!isdigit(c);c=gc());
    	for(;isdigit(c);now=now*10+c-48,c=gc());
    	return now;
    }
    inline int FP(int x,int k,const int mod)
    {
    	int t=1;
    	for(; k; k>>=1,x=1ll*x*x%mod)
    		if(k&1) t=1ll*t*x%mod;
    	return t;
    }
    inline LL Mult(LL a,LL b,LL p)
    {
    	LL t=a*b-(LL)((long double)a/p*b+1e-8)*p;
    	return t<0?t+p:t;
    }
    void NTT(int *a,int lim,int opt,const int mod)
    {
    	const int invG=FP(G,mod-2,mod);
    	for(int i=1; i<lim; ++i) if(i<rev[i]) std::swap(a[i],a[rev[i]]);//&&
    	for(int i=2; i<=lim; i<<=1)
    	{
    		int mid=i>>1,Wn=FP(~opt?G:invG,(mod-1)/i,mod);
    		for(int j=0; j<lim; j+=i)
    			for(int t,w=1,k=j; k<j+mid; ++k,w=1ll*w*Wn%mod)
    				a[k+mid]=a[k]+mod-(t=1ll*a[k+mid]*w%mod), MOD(a[k+mid],mod),
    				ADD(a[k],t,mod);
    	}
    	if(opt==-1) for(int i=0,inv=FP(lim,mod-2,mod); i<lim; ++i) a[i]=1ll*a[i]*inv%mod;
    }
    void Solve(const int lim,int *ans,const int mod)
    {
    	memcpy(A,tA,lim<<2), memcpy(B,tB,lim<<2);//不能只赋值到n!(清空后面的)
    	NTT(A,lim,1,mod), NTT(B,lim,1,mod);
    	for(int i=0; i<lim; ++i) A[i]=1ll*A[i]*B[i]%mod;
    	NTT(A,lim,-1,mod);
    	for(int i=0; i<lim; ++i) ans[i]=A[i];
    }
    
    int main()
    {
    	const int n=read(),m=read(),P=read();
    	for(int i=0; i<=n; ++i) tA[i]=read(),tA[i]>=P&&(tA[i]%=P);
    	for(int i=0; i<=m; ++i) tB[i]=read(),tB[i]>=P&&(tB[i]%=P);
    	int lim=1,l=-1; while(lim<=n+m) lim<<=1,++l;
    	for(int i=1; i<lim; ++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<l);
    	for(int i=0; i<3; ++i) Solve(lim,Ans[i],Mod[i]);
    	#define m1 998244353ll
    	#define m2 469762049ll
    	#define m3 1004535809ll
    	const LL M12=m1*m2,inv2=FP(m2,m1-2,m1),inv1=FP(m1,m2-2,m2),mul2=m2*inv2%M12,mul1=m1*inv1%M12,
    			inv=FP(M12%m3,m3-2,m3),m12=M12%P;
    	for(int i=0,t=n+m; i<=t; ++i)
    	{
    		LL c1=Ans[0][i],c2=Ans[1][i],c3=Ans[2][i],c4=Mult(c1,mul2,M12)+Mult(c2,mul1,M12);
    		MOD(c4,M12);
    		LL a=((c3-c4)%m3+m3)/*%m3*/*inv%m3;
    		printf("%d ",(int)((a*m12+c4)%P));
    	}
    
    	return 0;
    }
    

    (MTT)

    //1094ms	26.14MB
    #include <cmath>
    #include <cstdio>
    #include <cctype>
    #include <algorithm>
    #define gc() getchar()
    #define MAXIN 300000
    //#define gc() (SS==TT&&(TT=(SS=IN)+fread(IN,1,MAXIN,stdin),SS==TT)?EOF:*SS++)
    typedef long long LL;
    typedef long double ld;
    //#define double long double
    const int N=(1<<18)+3;
    const double PI=acos(-1);
    
    int rev[N];
    char IN[MAXIN],*SS=IN,*TT=IN;
    struct Complex
    {
    	double x,y;
    	Complex(double x=0,double y=0):x(x),y(y) {}
    	inline Complex operator +(const Complex &a)const {return Complex(x+a.x, y+a.y);}
    	inline Complex operator -(const Complex &a)const {return Complex(x-a.x, y-a.y);}
    	inline Complex operator *(const Complex &a)const {return Complex(x*a.x-y*a.y, y*a.x+x*a.y);}
    }A[N],B[N],C[N],D[N],W[N];
    
    inline int read()
    {
    	int now=0; register char c=gc();
    	for(;!isdigit(c);c=gc());
    	for(;isdigit(c);now=now*10+c-'0',c=gc());
    	return now;
    }
    void FFT(Complex *a,const int lim,const int opt)
    {
    	static Complex w[N];
    	for(int i=1; i<lim; ++i) if(i<rev[i]) std::swap(a[i],a[rev[i]]);
    	for(int i=2; i<=lim; i<<=1)
    	{
    		int mid=i>>1;
    		if(~opt) for(int k=0,t=lim/i; k<mid; ++k) w[k]=W[k*t];
    		else for(int k=0,t=lim/i; k<mid; ++k) w[k]=Complex(W[k*t].x,-W[k*t].y);
    		for(int j=0; j<lim; j+=i)
    		{
    			Complex t;
    			for(int k=j; k<j+mid; ++k)
    				a[k+mid]=a[k]-(t=w[k-j]*a[k+mid]), a[k]=a[k]+t;
    		}
    	}
    	if(opt==-1) for(int i=0; i<lim; ++i) a[i].x/=lim;
    }
    
    int main()
    {
    	const int n=read(),m=read(),P=read();
    	for(int i=0,t; i<=n; ++i) t=read(),A[i].x=t>>15,B[i].x=t&32767;
    	for(int i=0,t; i<=m; ++i) t=read(),C[i].x=t>>15,D[i].x=t&32767;
    
    	int lim=1,l=-1; while(lim<=n+m) lim<<=1,++l;
    	for(int i=0; i<lim; ++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<l);
    	for(int i=0,t=lim>>1; i<lim; ++i) W[i]=Complex(std::cos(i*PI/t),std::sin(i*PI/t));
    //	const Complex Wn(std::cos(2.0*PI/lim),std::sin(2.0*PI/lim)); Complex w(1,0);
    //	for(int i=0; i<lim; ++i) W[i]=w, w=w*Wn;//精度巨差 
    
    	FFT(A,lim,1), FFT(B,lim,1), FFT(C,lim,1), FFT(D,lim,1);
    	for(int i=0; i<lim; ++i)
    	{
    		const Complex a=A[i],b=B[i],c=C[i],d=D[i];
    		A[i]=a*c, B[i]=a*d+b*c, D[i]=b*d;
    	}
    	FFT(A,lim,-1), FFT(B,lim,-1), FFT(D,lim,-1);
    	for(int i=0,t=n+m; i<=t; ++i)
    	{
    		LL res=((LL(A[i].x+0.5))%P<<30)+((LL(B[i].x+0.5))%P<<15)+(LL(D[i].x+0.5));
    		printf("%d ",(int)(res%P));
    	}
    
    	return 0;
    }
    
  • 相关阅读:
    面向对象编程——设计模式之一
    mysql死锁——mysql之四
    Mysql基本类型(字符串类型)——mysql之二
    Mysql基本类型(五种年日期时间类型)——mysql之二
    Mysql基础教程——mysql之一
    JVM启动参数手册——JVM之八
    Thinkphp 框架2
    Thinkphp 框架
    流程(下)
    流程(上)
  • 原文地址:https://www.cnblogs.com/SovietPower/p/10546764.html
Copyright © 2020-2023  润新知