• 洛谷 4245 【模板】任意模数NTT——三模数NTT / 拆系数FFT


    题目:https://www.luogu.org/problemnew/show/P4245

    三模数NTT:

      大概是用3个模数分别做一遍,用中国剩余定理合并。

      前两个合并起来变成一个 long long 的模数,再要和第三个合并的话就爆 long long ,所以可以用一种让两个模数的乘积不出现的方法:https://blog.csdn.net/qq_35950004/article/details/79477797

       x*m1+a1 = -y*m2 + a2  <==>  x*m1+y*m2 = a2-a1  <==>  x*m1 = a2-a1 (mod m2)  <==> x=(a2-a1)*m1^{-1} (mod m2)

      然后根据该博客里的证明,在mod m2意义下算出来的 x 就是真的 x 。这样的话答案就是 x*m1+a1 ,可以在快速乘的过程中对题目中给的模数取模,就不会爆 long long 啦。

      注意输入的 a[ ] 和 b[ ] 不能 ntt( ,0, ) 之后再 ntt( ,1, ) 回来,因为值已经模了刚才那个模数了;所以要多开一些数组。

      注意输入进 mul 里的 a 和 b 应该是正的,不然没法 b>>=1 之类的。

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #define ll long long
    using namespace std;
    const int N=1e5+5;
    int m[3]={998244353,1004535809,469762049};
    int n0,n1,mod,len,r[N<<2],a[3][N<<2],b[3][N<<2],c[3][N<<2];
    ll M=(ll)m[0]*m[1],d[N<<1];
    int rdn()
    {
      int ret=0;bool fx=1;char ch=getchar();
      while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
      while(ch>='0'&&ch<='9') ret=ret*10+ch-'0',ch=getchar();
      return fx?ret:-ret;
    }
    void upd(ll &x,ll md){x>=md?x-=md:0;}
    void upd(int &x,ll md){x>=md?x-=md:0;}
    ll mul(ll a,ll b,ll md)
    {
      a%=md; b%=md;//
      ll ret=0;while(b){if(b&1ll)ret+=a,upd(ret,md);a+=a;upd(a,md);b>>=1ll;}return ret;
    }
    ll pw(ll x,ll k,ll md)
    {ll ret=1;while(k){if(k&1ll)ret=mul(ret,x,md);x=mul(x,x,md);k>>=1ll;}return ret;}
    void ntt(int *a,bool fx,int md)
    {
      for(int i=0;i<len;i++)
        if(i<r[i])swap(a[i],a[r[i]]);
      for(int R=2;R<=len;R<<=1)
        {
          int m=R>>1;
          int Wn=pw(3,(md-1)/R,md);
          fx?Wn=pw(Wn,md-2,md):0;
          for(int i=0;i<len;i+=R)
        for(int j=0,w=1;j<m;j++,w=(ll)w*Wn%md)
          {
            int tmp=(ll)w*a[i+m+j]%md;
            a[i+m+j]=a[i+j]+md-tmp; upd(a[i+m+j],md);
            a[i+j]=a[i+j]+tmp; upd(a[i+j],md);
          }
        }
      if(!fx)return; int inv=pw(len,md-2,md);
      for(int i=0;i<len;i++) a[i]=(ll)a[i]*inv%md;
    }
    int main()
    {
      n0=rdn()+1; n1=rdn()+1; mod=rdn();
      for(int i=0;i<n0;i++)a[0][i]=a[1][i]=a[2][i]=rdn();
      for(int i=0;i<n1;i++)b[0][i]=b[1][i]=b[2][i]=rdn();
      for(len=1;len<=n0+n1;len<<=1);
      for(int i=0;i<len;i++)r[i]=(r[i>>1]>>1)+((i&1)?len>>1:0);
      for(int i=0;i<3;i++)//don't ntt(a,1,m[i]) for it can't return(already mod)
        {
          ntt(a[i],0,m[i]); ntt(b[i],0,m[i]);
          for(int j=0;j<len;j++)c[i][j]=(ll)a[i][j]*b[i][j]%m[i];
          ntt(c[i],1,m[i]);
        }
    
      ll inv=pw(m[0],m[1]-2,m[1]),t;
      for(int i=0,lm=n0+n1-1;i<lm;i++)
        {
          t=mul((c[1][i]-c[0][i])%m[1]+m[1],inv,m[1]);
          d[i]=(mul(t,m[0],M)+c[0][i])%M;
        }
      inv=pw(M,m[2]-2,m[2]);
      for(int i=0,lm=n0+n1-1;i<lm;i++)
        {
          t=mul((c[2][i]-d[i])%m[2]+m[2],inv,m[2]);
          d[i]=(mul(t,M,mod)+d[i])%mod;
          printf("%lld ",d[i]);
        }
      puts(""); return 0;
    }
    View Code

    拆系数FFT:

      参考材料:https://blog.csdn.net/lvzelong2014/article/details/80156989

      因为数值最大可能 10^(9+9+5) ,double 的精度可能很不好。所以把 a[ i ] 拆成 k*m + b ,其中 m 大约是 sqrt(mod) ;

      这样的话2个多项式变成了4个多项式,卷积的时候只卷积 k 和 b ,不用带上 m ,数值最大就是 sqrt(mod) 级别的,精度就能行了。

      但 ( k1 + b1 )*( k2 + b2 ) = k1*k2 + b1*k2 + k1*b2 + b1*b2 ,需要4次卷积,做8次DFT(似乎可以弄成7次),太慢了。

      所以就像那个博客里说的那样:

      令 ( P(x) = A(x)+i*B(x) , Q(x) = A(x)-i*B(x) ) (写的时候就是把 a[ i ] 放在 p[ i ] 的实部、把 b[ i ] 放在 p[ i ] 的虚部)

      则 ( Q(w_{n}^{k}) = conj(P(w_{n}^{-k})) )(这里已经代入了点值。即把 P DFT之后,可以根据它得到 Q 的点值)(证明见那个博客)

      然后由 P(x) 和 Q(x) 的定义,得

      ( A(w_{n}^{k}) = frac{P(w_{n}^{k})+Q(w_{n}^{k})}{2} )(点值)

      ( i*B(w_{n}^{k}) = frac{P(w_{n}^{k})-Q(w_{n}^{k})}{2} ) 即 ( B(w_{n}^{k}) = i*frac{Q(w_{n}^{k})-P(w_{n}^{k})}{2} )

      把向量写成 ( 实部+ i * 虚部 ) 的样子,就能把 i 代进去,从而得到 A(x) 和 B(x) 的点值向量。这种方法利用 P(x) 和 Q(x) 的关联,只 DFT 了一下P(x),就求出了A(x)和B(x)两个多项式的点值!

      所以 DFT 两次,求出 k1、b1、k2、b2 的点值。相乘一番,得到 k1*k2 、b1*k2 、k1*b2 、b1*b2 的点值。

      考虑 iDFT 回去。还是设 ( P(w_{n}^{k}) = A(w_{n}^{k}) + i*B(w_{n}^{k}))(点值),写的时候还是把向量写开,就能把 i 代进去,得到 ( P(w_{n}^{k}) ) 的向量。

      这时的 ( P(w_{n}^{k}) ) 是一组点值, iDFT 回去后直接是系数的 ( A(x) + i*B(x) ) !也就是 iDFT 后 p[ i ] 的实部是 A(x) 的系数 a[ i ] ,虚部是 b[ i ] 。

        这个的证明可以看那个博客,大概就是如果系数是( A 的系数 + B 的系数 )的话,点值也应该是 ( A 的点值 + B 的点值 );倒着推一下得到点值是 ( A 的点值 + B 的点值 ) 的话,iDFT 回去的系数也应该是 ( A 的系数 + B 的系数 ) 。

      这里的A(x)和B(x)就是 k1*k2 、b1*k2 、k1*b2 、b1*b2 中的两个。两次 iDFT 后就得到了这4个东西的系数表达!

      然后统计答案就行了。注意乘上 m2 或者 m 。

      这道题得开 long double 。

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #include<cmath>
    #define db long double
    #define ll long long
    using namespace std;
    const int N=1e5+5,M=(1<<18)+5; const db pi=acos(-1);
    int mod,a[N],b[N],len,r[M],ans[M];
    struct cpl{db x,y;}pa[M],pb[M],Ta[M],Tb[M],Tc[M],Td[M],I;
    cpl operator+ (cpl a,cpl b){return (cpl){a.x+b.x,a.y+b.y};}
    cpl operator- (cpl a,cpl b){return (cpl){a.x-b.x,a.y-b.y};}
    cpl operator* (cpl a,cpl b){return (cpl){a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x};}
    cpl cnj(cpl a){return (cpl){a.x,-a.y};}
    void upd(int &x){x>=mod?x-=mod:0;}
    int rdn()
    {
      int ret=0;bool fx=1;char ch=getchar();
      while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
      while(ch>='0'&&ch<='9') ret=ret*10+ch-'0',ch=getchar();
      return fx?ret:-ret;
    }
    void fft(cpl *a,bool fx)
    {
      for(int i=0;i<len;i++)
        if(i<r[i])swap(a[i],a[r[i]]);
      for(int R=2;R<=len;R<<=1)
        {
          int m=R>>1;
          cpl Wn=(cpl){ cos(pi/m),fx?-sin(pi/m):sin(pi/m) };
          for(int i=0;i<len;i+=R)
        {
          cpl w=I;
          for(int j=0;j<m;j++,w=w*Wn)
            {
              cpl x=a[i+j],y=w*a[i+m+j];
              a[i+j]=x+y;  a[i+m+j]=x-y;
            }
        }
        }
      if(!fx)return;
      for(int i=0;i<len;i++)a[i].x/=len,a[i].y/=len;//y!! for use y
    }
    void solve(int n,int *a,int m,int *b)
    {
      int bin=(1<<15)-1; cpl ta,tb,tc,td;
    
      for(int i=0;i<=n;i++)pa[i]=(cpl){a[i]>>15,a[i]&bin};//pa=A+iB(a=km+b,A=k,B=b)
      for(int i=0;i<=m;i++)pb[i]=(cpl){b[i]>>15,b[i]&bin};//<=m!!!
      for(len=1;len<=n+m;len<<=1);
      for(int i=0;i<len;i++)r[i]=(r[i>>1]>>1)+((i&1)?len>>1:0);
      fft(pa,0);  fft(pb,0);
      pa[len]=pa[0]; pb[len]=pb[0];
      for(int i=0,j=len;i<len;i++,j--)//j=-i+len(if i=0 then j=0),qa[i]=cnj(pa[j])
        {
          //ta:point value of A
          ta=(pa[i]+cnj(pa[j]))*(cpl){0.5,0};//ta={(pa[i].x+pa[j].x)/2,(pa[i].y+pa[j].y)/2}
          tb=(pa[i]-cnj(pa[j]))*(cpl){0,-0.5};//tb=i(Q-P)/2=i*((Qx-Px)+i(Qy-Py))/2={(Py-Qy)/2,(Qx-Px)/2}
          tc=(pb[i]+cnj(pb[j]))*(cpl){0.5,0};
          td=(pb[i]-cnj(pb[j]))*(cpl){0,-0.5};
          Ta[i]=ta*tc;Tb[i]=ta*td;Tc[i]=tb*tc;Td[i]=tb*td;
        }
      pa[len]=pb[len]=(cpl){0,0};
      for(int i=0;i<len;i++)pa[i]=Ta[i]+Tb[i]*(cpl){0,1};//pa=Ta+i*Tb={Tax-Tby,Tay+Tbx}
      for(int i=0;i<len;i++)pb[i]=Tc[i]+Td[i]*(cpl){0,1};
      fft(pa,1);  fft(pb,1);//pa.x=Ta,pa.y=Tb
      for(int i=0,j=n+m;i<=j;i++)
        {
          int Da=(ll)(pa[i].x+0.5)%mod;
          int Db=(ll)(pa[i].y+0.5)%mod;
          int Dc=(ll)(pb[i].x+0.5)%mod;
          int Dd=(ll)(pb[i].y+0.5)%mod;
          ans[i]=(((ll)Da<<30) + ((ll)(Db+Dc)<<15) + Dd)%mod+mod; upd(ans[i]);
        }
    }
    int main()
    {
      int n,m; I.x=1;
      n=rdn();m=rdn();mod=rdn();
      for(int i=0;i<=n;i++)a[i]=rdn()%mod+mod,upd(a[i]);
      for(int i=0;i<=m;i++)b[i]=rdn()%mod+mod,upd(b[i]);
      solve(n,a,m,b);
      for(int i=0,j=n+m;i<=j;i++)printf("%d ",ans[i]);puts("");
      return 0;
    }
    View Code
  • 相关阅读:
    什么是交互式?
    python之禅
    爬虫保存cookies时重要的两个参数(ignore_discard和ignore_expires)的作用
    PL/0编译器(java version) – Symbol.java
    PL/0编译器(java version) – Scanner.java
    PL/0编译器(java version)–Praser.java
    PL/0编译器(java version)–PL0.java
    PL/0编译器(java version)–Pcode.java
    PL/0编译器(java version)
    PL/0编译器(java version)
  • 原文地址:https://www.cnblogs.com/Narh/p/10035325.html
Copyright © 2020-2023  润新知