题目
题目链接:https://www.luogu.com.cn/problem/P3803
给定一个 \(n\) 次多项式 \(F(x)\),和一个 \(m\) 次多项式 \(G(x)\)。
请求出 \(F(x)\) 和 \(G(x)\) 的卷积。
思路
本来以为这个数论菜比只能背板的。。。结果差不多看懂了?
强烈推荐: Blog1 Blog2
其实就是菜不想码公式而已 /fad。
\(\operatorname{Update\ on\ 2020.09.17:}\) 增加了 NTT 实现代码。
代码
#include <bits/stdc++.h>
#define cp complex<double>
using namespace std;
const int N=3000010;
const double pi=acos(-1);
int n,m,Maxn=1,rev[N];
cp f[N],g[N];
int read()
{
int d=0; char ch=getchar();
while (!isdigit(ch)) ch=getchar();
while (isdigit(ch)) d=(d<<3)+(d<<1)+ch-48,ch=getchar();
return d;
}
void fft(cp *f,int tag)
{
for (int i=0;i<Maxn;i++)
if (i<rev[i]) swap(f[i],f[rev[i]]);
for (int mid=1;mid<Maxn;mid<<=1)
{
cp temp(cos(pi/mid),tag*sin(pi/mid));
for (int i=0;i<Maxn;i+=(mid<<1))
{
cp w(1,0);
for (int j=0;j<mid;j++,w*=temp)
{
cp x=f[i+j],y=w*f[i+j+mid];
f[i+j]=x+y; f[i+j+mid]=x-y;
}
}
}
}
int main()
{
n=read(); m=read();
for (int i=0;i<=n;i++) f[i]=cp(1.0*read(),0.0);
for (int i=0;i<=m;i++) g[i]=cp(1.0*read(),0.0);
n+=m;
while (Maxn<=n) Maxn<<=1;
for (int i=0;i<Maxn;i++)
rev[i]=(rev[i>>1]>>1)|((i&1)?(Maxn>>1):0);
fft(f,1); fft(g,1);
for (int i=0;i<Maxn;i++) f[i]*=g[i];
fft(f,-1);
for (int i=0;i<=n;i++)
printf("%d ",(int)(f[i].real()/Maxn+0.4999));
return 0;
}
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=(2<<21)+10,MOD=998244353,G=3,Ginv=332748118;
int n,m,lim,rev[N];
ll f[N],g[N];
int read()
{
int d=0; char ch=getchar();
while (!isdigit(ch)) ch=getchar();
while (isdigit(ch)) d=(d<<3)+(d<<1)+ch-48,ch=getchar();
return d;
}
ll fpow(ll x,int k)
{
ll ans=1;
for (;k;k>>=1,x=x*x%MOD)
if (k&1) ans=ans*x%MOD;
return ans;
}
void ntt(ll *f,int inv)
{
for (int i=0;i<lim;i++)
if (i<rev[i]) swap(f[i],f[rev[i]]);
for (int mid=1;mid<lim;mid<<=1)
{
ll tmp=fpow(inv==1 ? G : Ginv,(MOD-1)/(mid<<1));
for (int i=0;i<lim;i+=(mid<<1))
{
ll w=1;
for (int j=0;j<mid;j++,w=w*tmp%MOD)
{
int x=f[i+j],y=w*f[i+j+mid]%MOD;
f[i+j]=(x+y)%MOD; f[i+j+mid]=(x-y+MOD)%MOD;
}
}
}
}
int main()
{
n=read(); m=read();
for (int i=0;i<=n;i++) f[i]=(read()+MOD)%MOD;
for (int i=0;i<=m;i++) g[i]=(read()+MOD)%MOD;
n+=m; lim=1;
while (lim<=n) lim<<=1;
for (int i=0;i<lim;i++)
rev[i]=(rev[i>>1]>>1)|((i&1)?(lim>>1):0);
ntt(f,1); ntt(g,1);
for (int i=0;i<lim;i++)
f[i]=f[i]*g[i]%MOD;
ntt(f,-1);
int inv=fpow(lim,MOD-2);
for (int i=0;i<=n;i++)
printf("%d ",f[i]*inv%MOD);
return 0;
}