Description
模板题啦~
关于傅里叶变换,我有一篇博客做了详细介绍傅里叶变换 - Fourier Transform。
Code
//多项式乘法 - FFT
#include <cmath>
#include <complex>
#include <cstdio>
using namespace std;
typedef complex<double> cpx;
int const N=3e5+10;
double const PI=acos(-1);
int n,m,t; int pos[N];
cpx a[N],b[N],c[N];
void FFT(cpx *x,int f)
{
for(int i=0;i<t;i++) if(i<pos[i]) swap(x[i],x[pos[i]]);
for(int i=1;i<t;i<<=1)
{
cpx Wn=cpx(cos(PI/i),f*sin(PI/i));
for(int j=0;j<t;j+=i+i)
{
cpx w=cpx(1,0);
for(int k=0;k<i;k++,w*=Wn)
{
cpx p=x[j+k],q=w*x[j+k+i];
x[j+k]=p+q,x[j+k+i]=p-q;
}
}
}
if(f==-1) for(int i=0;i<t;i++) x[i]/=t;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=0;i<=n;i++) scanf("%lf",&a[i]);
for(int i=0;i<=m;i++) scanf("%lf",&b[i]);
t=1; int n0=0; while(t<=n+m) t<<=1,n0++;
for(int i=0;i<t;i++) pos[i]=(pos[i>>1]>>1)|((i&1)<<(n0-1));
FFT(a,1),FFT(b,1);
for(int i=0;i<t;i++) c[i]=a[i]*b[i];
FFT(c,-1);
for(int i=0;i<=n+m;i++) printf("%d ",(int)(c[i].real()+0.5));
puts("");
return 0;
}
//多项式乘法 - NTT
#include <algorithm>
#include <cstdio>
using namespace std;
typedef long long lint;
inline char gc()
{
return getchar();
static char now[1<<16],*S,*T;
if(S==T) {T=(S=now)+fread(now,1,1<<16,stdin); if(S==T) return EOF;}
return *S++;
}
inline int read()
{
int x=0; char ch=gc();
while(ch<'0'||'9'<ch) ch=gc();
while('0'<=ch&&ch<='9') x=x*10+ch-'0',ch=gc();
return x;
}
int const N=(1<<20)+10;
lint const P=998244353;
int n,m,t;
int pos[N];
lint a[N],b[N],c[N];
lint pow(lint x,int y)
{
lint res=1,t=x;
while(y) {if(y&1) res=res*t%P; t=t*t%P,y>>=1;}
return res;
}
void NTT(lint x[],int f)
{
for(int i=0;i<t;i++) if(i<pos[i]) swap(x[i],x[pos[i]]);
for(int i=1;i<t;i<<=1)
{
lint Wn=pow(3,(P-1)/i/2);
if(f==-1) Wn=pow(Wn,P-2);
for(int j=0;j<t;j+=i+i)
{
lint w=1;
for(int k=0;k<i;k++,w=w*Wn%P)
{
lint p=x[j+k],q=w*x[j+k+i]%P;
x[j+k]=(p+q)%P,x[j+k+i]=(p-q+P)%P;
}
}
}
lint iT=pow(t,P-2);
if(f==-1) for(int i=0;i<t;i++) x[i]=x[i]*iT%P;
}
int main()
{
n=read(),m=read();
for(int i=0;i<=n;i++) a[i]=read();
for(int i=0;i<=m;i++) b[i]=read();
t=1; int n0=0; while(t<=n+m) t<<=1,n0++;
for(int i=0;i<t;i++) pos[i]=(pos[i>>1]>>1)|((i&1)<<(n0-1));
NTT(a,1),NTT(b,1);
for(int i=0;i<t;i++) c[i]=a[i]*b[i]%P;
NTT(c,-1);
for(int i=0;i<=n+m;i++) printf("%lld ",c[i]); puts("");
return 0;
}