题目背景
这是一道 FFT 模板题
题目描述
给定一个 n 次多项式 F(x),和一个 m 次多项式 G(x)。
请求出 F(x) 和 G(x) 的卷积。
输入格式
第一行 2 个正整数 n,m。
接下来一行 n+1 个数字,从低到高表示 F(x) 的系数。
接下来一行 m+1 个数字,从低到高表示 G(x) 的系数。
输出格式
一行 n+m+1 个数字,从低到高表示 F(x)*G(x) 的系数。
#include<cmath>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
const double Pi=acos(-1);
#define db double
#define maxn 1350000
inline int read(){
register char ch=0;
while(ch<48||ch>57)ch=getchar();
return ch-'0';
}
int n,m;
struct CP{
CP (db xx=0,db yy=0){x=xx;y=yy;}
db x,y;
CP operator + (CP const &B)const
{return CP(x+B.x,y+B.y);}
CP operator - (CP const &B)const
{return CP(x-B.x,y-B.y);}
CP operator * (CP const &B)const
{return CP(x*B.x-y*B.y,x*B.y+y*B.x);}
}f[maxn<<1];
int tr[maxn<<1];
inline void fft(CP *f,bool flag){
for(int i=0;i<n;i++)if(i<tr[i])swap(f[i],f[tr[i]]);
for(int p=2;p<=n;p<<=1){
int len=p>>1;
CP tG(cos(2*Pi/p),sin(2*Pi/p));
if(!flag)tG.y*=-1;
for(int k=0;k<n;k+=p){
CP buf(1,0);
for(int l=k;l<k+len;l++){
CP tt=buf*f[len+l];
f[len+l]=f[l]-tt;
f[l]=f[l]+tt;
buf=buf*tG;
}
}
}
}
signed main(){
scanf("%d%d",&n,&m);
for(int i=0;i<=n;i++)f[i].x=read();
for(int i=0;i<=m;i++)f[i].y=read();
for(m+=n,n=1;n<=m;n<<=1);
for(int i=0;i<n;i++)
tr[i]=(tr[i>>1]>>1)|((i&1)?n>>1:0);
fft(f,1);
for(int i=0;i<n;i++)f[i]=f[i]*f[i];
fft(f,0);
for(int i=0;i<=m;i++)
printf("%d ",(int)(f[i].y/n/2+0.49));
return 0;
}
hzwer的代码:
#include<bits/stdc++.h>
#define N 262145
#define pi acos(-1)
using namespace std;
typedef complex<double> E;
int n,m,L;
int R[N];
E a[N],b[N];
void fft(E *a,int f)
{
for(int i=0;i<n;i++)if(i<R[i])swap(a[i],a[R[i]]);
for(int i=1;i<n;i<<=1)
{
E wn(cos(pi/i),f*sin(pi/i));
for(int p=i<<1,j=0;j<n;j+=p)
{
E w(1,0);
for(int k=0;k<i;k++,w*=wn)
{
E x=a[j+k],y=w*a[j+k+i];
a[j+k]=x+y;a[j+k+i]=x-y;
}
}
}
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=0,x;i<=n;i++)scanf("%d",&x),a[i]=x;
for(int i=0,x;i<=m;i++)scanf("%d",&x),b[i]=x;
m=n+m;for(n=1;n<=m;n<<=1)L++;
for(int i=0;i<n;i++)R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
fft(a,1);fft(b,1);
for(int i=0;i<=n;i++)a[i]=a[i]*b[i];
fft(a,-1);
for(int i=0;i<=m;i++)
printf("%d ",(int)(a[i].real()/n+0.5));
return 0;
}
NTT写法:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
inline ll ty() {
char ch = getchar(); ll x = 0, f = 1;
while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); }
return x * f;
}
const int _=4e6+10;
const ll P=998244353,G=3,Gx=332748118;
int N,M,r[_];
ll A[_],B[_];
inline ll ksm(ll a,ll b) {
ll ret=1;
for (;b;b>>=1){
if (b & 1)ret=ret*a%P;
a=a*a%P;
}
return ret;
}
inline void ntt(int lim, ll *a, int op) {
for (int i=0;i<lim;++i)if(i<r[i])swap(a[i],a[r[i]]);
for (int len=2;len<=lim;len<<=1){
int mid=len >> 1;
ll Wn=ksm(op==1?G:Gx,(P-1)/len);
for (int i=0;i<lim;i+=len) {
ll w=1;
for (int j=0;j<mid;++j,w=(w*Wn)%P){
ll x=a[i+j],y=w*a[i+j+mid]%P;
a[i+j]=(x+y)%P;
a[i+j+mid]=(x-y+P)%P;
}
}
}
}
int main() {
N=ty(),M=ty();
for(int i = 0; i <= N; ++i) A[i]=(ty() + P) % P;
for(int i = 0; i <= M; ++i) B[i]=(ty() + P) % P;
int lim = 1, k = 0;
while (lim <= N + M) lim <<= 1, ++k;
for (int i = 0; i < lim ; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (k - 1));
ntt(lim, A, 1);
ntt(lim, B, 1);
for (int i = 0; i < lim; ++i) A[i] = (A[i] * B[i]) % P;
ntt(lim, A, -1);
ll invx = ksm(lim, P - 2);
for (int i = 0; i <= N + M; ++i)
printf("%lld ", (A[i] * invx) % P);
return 0;
}