快速傅里叶变换(FFT)
梦开始的地方
注:以上指噩梦
在刚入门的时候想必我们都学过高精度乘法
仿照高精度乘法的思想,直接将两个 (n) 次多项式相乘的时空复杂度为 (O(n^2))
这不够快,但是现在我们拿相乘的两层循环丝毫没有办法
需要另想方法来完成多项式相乘
我们知道,(n+1) 个点可以唯一确定一个 (n) 次多项式
例如 (3) 个点可以确定一个二次多项式 (ax^2+bx+c)
以下设相乘的两个多项式为 (f(x)) 和 (g(x))
次数分别为 (n) 与 (m)
那么我们只要在 (f(x)) 与 (g(x)) 上分别找 (n+m+1) 个点,计算它们的值,然后进行点值相乘,最后还原(插值)就可以了
点值相乘的时间复杂度只有 (O(n))!
现在来看,我们找到了解决方案
不过等等,在 (n) 次多项式上找 (n+m+1) 个点的时间复杂度是 (O(n^2))!
想办法使用分治!
对于一个多项式 (f(x)=a_0+a_1x+a_2x^2+dots+a_{n-1}x^{n-1})
(设:(n=2^k))
我们将奇次项分为一部分,将偶次项分为一部分
(f(x)=(a_0+a_2x^2+dots+a_{n-2}x^{n-2})+(a_1x^1+a_3x^3+dots+a_{n-1}x^{n-1}))
中间加号左边设为 (fl(x))
(fl(x)=a_0+a_2x+dots+a_{n-2}x^{n/2-1})
右边设为 (fr(x))
(fr(x)=a_1+a_3x+dots+a_{n-1}x^{n/2-1})
显然 (f(x)=fl(x^2)+xfr(x^2))
我们令 (x=omega_n^k)
我们令 (x=omega_n^{k+n/2})
此时 (f(omega_n^k)) 与 (f(omega_n^{k+n/2})) 只有一个符号的区别
并且我们注意到 (fl(x)) 与 (fr(x)) 的性质与 (f(x)) 完全相同,我们可以对 (fl(x)) 和 (fr(x)) 继续分治
至此,我们在 (O(nlogn)) 的时间内完成了将一个函数转化为点值表示的过程,可记作 ( ext{DFT}(f))
点值相乘是很简单的,之后我们需要把点值表示重新转化为系数表示
即求出 ( ext{IDFT}( ext{DFT}(f)))
设点值向量 (vec{G}= ext{DFT}(f)={y_0,y_1,dots,y_{n-1}})
此时我们将 (vec{G}) 当作系数向量再构造点值向量 (vec{H}=sumlimits_{i=0}^{n-1}G_i(omega_n^{-k})^i)
即将单位根的倒数代入 (G_0+G_1x^1+dots+G_{n-1}x^{n-1}),化简
(sumlimits_{i=0}^{n-1}(omega_n^{j-k})^i) 其实就是一个等比数列求和
通过对求和分类讨论可以得出结论:(H_i=nf_iRightarrow f_i=dfrac{H_i}{n})
所以我们可以将一开始对 (f) 进行DFT的取值取倒,再对 (G) 进行一次DFT
综上,本问题得到解决
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#define N 3000001
#define fx(l,n) inline l n
#define set(l,n) memset(l,n,sizeof(l))
#define R register int
using namespace std;
const double pi=3.14159265358979323846264;
int n,m,x=1;
struct complex{
double real,img;
complex(double a=0,double b=0){real=a,img=b;}
complex operator + (const complex &b) const{return complex(real+b.real,img+b.img);}
complex operator - (const complex &b) const{return complex(real-b.real,img-b.img);}
complex operator * (const complex &b) const{return complex(real*b.real-img*b.img,real*b.img+img*b.real);}
}f[N],g[N],save[N];
fx(void,FFT)(complex *f,int len,short s){
if(len==1) return;
int hlen=len>>1;
complex *fl=f,*fr=f+hlen;
for(int i=0;i<len;i++) save[i]=f[i];
for(int i=0;i<hlen;i++){
fl[i]=save[i<<1];
fr[i]=save[i<<1|1];
}
FFT(fl,hlen,s);FFT(fr,hlen,s);
complex dw(cos(2*pi/len),sin(2*pi/len)),now(1,0);
dw.img*=s;
for(int i=0;i<hlen;i++){
save[i]=fl[i]+now*fr[i];
save[i+hlen]=fl[i]-now*fr[i];
now=now*dw;
}
for(int i=0;i<len;i++) f[i]=save[i];
}
signed main(){
scanf("%d%d",&n,&m);
for(int i=0;i<=n;i++) scanf("%lf",&f[i].real);
for(int i=0;i<=m;i++) scanf("%lf",&g[i].real);
while(x<=n+m) x<<=1;
FFT(f,x,1),FFT(g,x,1);
for(int i=0;i<x;i++) f[i]=f[i]*g[i];
FFT(f,x,-1);
for(int i=0;i<=n+m;i++) printf("%d ",(int)(f[i].real/x+0.5));
}
我们发现对于每层递归,我们都进行了数组拷贝,拷贝的重要目的就是为了实现 (f(x)=fl(x^2)+xfr(x^2)) 这个式子
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
第一次变为
0 2 4 6 8 10 12 14|1 3 5 7 9 11 13 15
第二次变为
0 4 8 12|2 6 10 14|1 5 9 13|3 7 11 15
...
第一次其实就是把二进制上第 (0) 位为 (0) 的分成一组,为 (1) 的分成一组
第二次把二进制上第 (1) 位为 (0) 的分成一组,为 (1) 的分成一组
以此类推...
观察发现这其实就是二进制反转,我们可以 (O(n)) 完成
for(R i=0;i<x;i++) bf[i]=(bf[i>>1]>>1)|((i&1)?x>>1:0);
最后我们将递归实现改为迭代实现
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#define N 2600010
#define fx(l,n) inline l n
#define set(l,n) memset(l,n,sizeof(l))
#define R register int
using namespace std;
const double pi=3.14159265358979323846264;
int n,m,x=1,bf[N];
struct complex{
double real,img;
complex(double a=0,double b=0){real=a,img=b;}
complex operator + (const complex &b) const{return complex(real+b.real,img+b.img);}
complex operator - (const complex &b) const{return complex(real-b.real,img-b.img);}
complex operator * (const complex &b) const{return complex(real*b.real-img*b.img,real*b.img+img*b.real);}
}f[N],g[N];
fx(void,FFT)(complex *f,short r){
R l,hl,s,i;
for(i=0;i<x;i++) if(i<bf[i]) swap(f[i],f[bf[i]]);
for(l=2,hl=1;l<=x;hl=l,l<<=1){
complex dw(cos(2*pi/l),r*sin(2*pi/l));
for(s=0;s<x;s+=l){
complex now(1,0);
for(i=s;i<s+hl;i++){
complex uni=now*f[i+hl];
f[i+hl]=f[i]-uni;f[i]=f[i]+uni;
now=now*dw;
}
}
}
}
signed main(){
scanf("%d%d",&n,&m);
for(R i=0;i<=n;i++) scanf("%lf",&f[i].real);
for(R i=0;i<=m;i++) scanf("%lf",&g[i].real);
while(x<=n+m) x<<=1;
for(R i=0;i<x;i++) bf[i]=(bf[i>>1]>>1)|((i&1)?x>>1:0);
FFT(f,1),FFT(g,1);
for(R i=0;i<x;i++) f[i]=f[i]*g[i];
FFT(f,-1);
for(R i=0;i<=n+m;i++) printf("%d ",(int)(f[i].real/x+0.5));
}
快速数论变换(NTT)
显然,因为各种三角函数参加计算,FFT会有精度丢失问题
所以我们需要找到一个单位根的替代品
但是数学家们已经证明了在复数域 (mathbb{C}) 中,单位根是唯一满足条件的数
我们所有的计算都是在模意义下的
我们可以引入原根
此时只需要证明原根满足单位根的性质
证明大都很简单,故此处只证明单位根其中一条性质: (omega_n^k=-omega_n^{k+frac{n}{2}})
换成原根就是 ((g^{frac{p-1}{n}})^k=-(g^{frac{p-1}{n}})^{k+frac{n}{2}}pmod{p})
先进行简单的化简:
我们看到原式中有一个负号,这不由得使我们想起了Wilson定理,接着我们逆推回去
拆开,由费马小定理得:(g^{frac{p-1}{2}}equiv-1pmod{p})
所以原式得证
接下来把代码中的单位根全部换掉即可
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 2500001
#define fx(l,n) inline l n
#define set(l,n) memset(l,n,sizeof(l))
#define R register int
#define int long long
using namespace std;
const int mod=998244353,pr=3;
int x=1,n,m,br[N],f[N],g[N],prp[N],invp,invx;
fx(int,pow)(int a,int b){
int ans=1;
while(b){
if(b&1) (ans*=a)%=mod;
(a*=a)%=mod;
b>>=1;
}
return ans;
}
fx(void,NTT)(int *f,short r){
R l,hl,exp,uni,s,i;
for(i=0;i<x;++i) if(i<br[i]) swap(f[i],f[br[i]]);
for(l=2,hl=1;l<=x;hl=l,l<<=1){
exp=pow(r==1?pr:invp,(mod-1)/l);
for(i=1;i<hl;i++) prp[i]=prp[i-1]*exp%mod;
for(s=0;s<x;s+=l){
for(i=0;i<hl;++i){
uni=prp[i]*f[i|s|hl]%mod;
f[i|s|hl]=(f[i|s]-uni+mod)%mod;
f[i|s]=(f[i|s]+uni)%mod;
}
}
}
}
signed main(){
scanf("%d%d",&n,&m);
for(R i=0;i<=n;++i) scanf("%d",&f[i]);
for(R i=0;i<=m;++i) scanf("%d",&g[i]);
while(x<=n+m) x<<=1;
prp[0]=1,invp=pow(pr,mod-2),invx=pow(x,mod-2);
for(R i=0;i<x;++i) br[i]=(br[i>>1]>>1)|((i&1)?x>>1:0);
NTT(f,1),NTT(g,1);
for(R i=0;i<x;++i) (f[i]*=g[i])%=mod;
NTT(f,-1);
for(R i=0;i<=n+m;++i) printf("%d ",f[i]*invx%mod);
}
多项式乘法逆
由多项式乘法,我们可知如果要求一个多项式 (f(x)) 的乘法逆 (g(x)),就要:
递推,复杂度 (O(n^2))
这个复杂度是很难让人满意的,所以我们想结合NTT,将其复杂度优化至 (O(nlog n))
上文介绍FFT/NTT时,曾提到一个核心思想:分治
我们又可以很容易的发现常数项的逆元是很好求的
所以我们假设已经求出了 (f^prime(x)),(f^prime(x)*f(x)=1 pmod{x^frac{n}{2}},g(x)*f(x)=1pmod{x^n})
多项式最高项次数每次扩大 (1) 倍,对于每次扩大NTT计算即可
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 2000001
#define fx(l,n) inline l n
#define set(l,n) memset(l,n,sizeof(l))
#define cpy(f,t,len) memcpy(t,f,sizeof(int)*len)
#define R register int
#define int long long
using namespace std;
const int mod=998244353,pr=3;
int x=1,n,m,br[N],f[N],g[N],prp[N],invp,invx,h[N];
fx(int,pow)(int a,int b){
int ans=1;
while(b){
if(b&1) (ans*=a)%=mod;
(a*=a)%=mod;
b>>=1;
}
return ans;
}
fx(void,NTT)(int *f,short r,const int x){
R l,hl,exp,uni,s,i;
for(R i=0;i<x;++i) br[i]=(br[i>>1]>>1)|((i&1)?x>>1:0);
for(i=0;i<x;++i) if(i<br[i]) swap(f[i],f[br[i]]);
for(l=2,hl=1;l<=x;hl=l,l<<=1){
exp=pow(r==1?pr:invp,(mod-1)/l);
for(i=1;i<hl;i++) prp[i]=prp[i-1]*exp%mod;
for(s=0;s<x;s+=l){
for(i=0;i<hl;++i){
uni=prp[i]*f[i|s|hl]%mod;
f[i|s|hl]=(f[i|s]-uni+mod)%mod;
f[i|s]=(f[i|s]+uni)%mod;
}
}
}
if(r==-1){
invx=pow(x,mod-2);
for(R o=0;o<x;++o) (f[o]*=invx)%=mod;
}
}
fx(void,inv)(int *f,int x){
static int j[N],k[N];
h[0]=pow(f[0],mod-2);
for(int len=2,hl=1;len<=x;hl=len,len<<=1){
for(int o=0;o<hl;o++) j[o]=(h[o]<<1)%mod;
cpy(f,k,len);
NTT(h,1,len<<1);
for(R o=0;o<(len<<1);++o) (h[o]*=h[o])%=mod;
NTT(k,1,len<<1);
for(R o=0;o<(len<<1);++o) (h[o]*=k[o])%=mod;
NTT(h,-1,len<<1);
for(R o=0;o<len;++o) h[o]=(j[o]-h[o]+mod)%mod;
memset(h+len,0,sizeof(int)*len);
}
}
signed main(){
scanf("%d",&n);
for(R i=0;i<n;++i) scanf("%d",&f[i]);
while(x<n) x<<=1;
prp[0]=1,invp=pow(pr,mod-2);
inv(f,x);
for(R o=0;o<n;++o) cout<<h[o]<<" ";
}
多项式对数函数(多项式ln)
直接推导:
按步骤直接做就可以
- 对 (A(x)) 求乘法逆
- 对 (A(x)) 求导
- NTT相乘
- 对 (dfrac{A^prime(x)}{A(x)}) 积分
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 1000001
#define INF 1100000000
#define fx(l,n) inline l n
#define set(l,n,ty,len) memset(l,n,sizeof(ty)*len)
#define cpy(f,t,ty,len) memcpy(t,f,sizeof(ty)*len)
#define R register int
#define C const
#define int long long
using namespace std;
C int mod=998244353,pr=3;
int n,br[N],ppr[N],x=1,invp,invx,A[N],B[N];
fx(int,gi)(){
char c=getchar();int s=0,f=1;
while(c<'0'||c>'9'){
if(c=='-') f=-1;
c=getchar();
}
while(c>='0'&&c<='9') s=(s<<3)+(s<<1)+(c-'0'),c=getchar();
return s*f;
}
fx(int,pow)(int a,int b=mod-2){
int ans=1;
while(b){
if(b&1) (ans*=a)%=mod;
(a*=a)%=mod;
b>>=1;
}
return ans;
}
fx(void,NTT)(int *f,C short r,C int x){
R len,hl,exp,uni,s,i;
for(i=0;i<x;i++) br[i]=(br[i>>1]>>1)|((i&1)?x>>1:0);
for(i=0;i<x;i++) if(i<br[i]) swap(f[i],f[br[i]]);
for(len=2,hl=1;len<=x;hl=len,len<<=1){
exp=pow(r==1?pr:invp,(mod-1)/len);
for(i=1;i<hl;i++) ppr[i]=ppr[i-1]*exp%mod;
for(s=0;s<x;s+=len){
for(i=0;i<hl;i++){
uni=ppr[i]*f[i|s|hl]%mod;
f[i|s|hl]=(f[i|s]-uni+mod)%mod;
f[i|s]=(f[i|s]+uni)%mod;
}
}
}
if(r==-1){
invx=pow(x);
for(i=0;i<x;i++) (f[i]*=invx)%=mod;
}
}
fx(void,INV)(int *f,const int x){
static int le[N],ri[N],inv[N];
inv[0]=pow(f[0]);
for(R len=2,hl=1,o;len<=x;hl=len,len<<=1){
for(o=0;o<hl;o++) le[o]=(inv[o]<<1)%mod;
cpy(f,ri,int,len);
NTT(inv,1,len<<1);NTT(ri,1,len<<1);
for(o=0;o<(len<<1);o++) (((inv[o]*=inv[o])%=mod)*=ri[o])%=mod;
NTT(inv,-1,len<<1);
for(o=0;o<len;o++) inv[o]=(le[o]-inv[o]+mod)%mod;
set(inv+len,0,int,len);
}
cpy(inv,f,int,n);
}
fx(void,DER)(int *f,C int len){
for(int i=1;i<len;i++) f[i-1]=f[i]*i%mod;
f[len-1]=0;
}
fx(void,INT)(int *f,C int len){
for(R i=len-1;i>=1;i--) f[i]=f[i-1]*pow(i)%mod;
f[0]=0;
}
signed main(){
n=gi();
for(R i=0;i<n;i++) A[i]=gi();
cpy(A,B,int,n);
while(x<n) x<<=1;
invp=pow(pr);ppr[0]=1;
INV(A,x);DER(B,n);
while(x<n+n) x<<=1;
NTT(A,1,x);NTT(B,1,x);
for(R i=0;i<x;i++) (A[i]*=B[i])%=mod;
NTT(A,-1,x);INT(A,n);
for(R i=0;i<n;i++) printf("%lld ",A[i]);
}
多项式指数函数(多项式exp)
多项式牛顿迭代在这里不做过多叙述
此处 (exp A(x)equiv B(x)pmod{x^n})
两边取对数得 (A(x)equivln B(x)pmod{x^n})
即 (ln B(x)-A(x)equiv0pmod{x^n})
利用多项式牛顿迭代结果:
(A(x)) 是确定的,故 (F(B(x))=ln B(x)-A(x),F^prime(B(x))=dfrac1{B(x)})
(B(x)=B_0(x)-B_0(x)(ln B_0(x)-A(x))=B_0(x)(1-ln B_0(x)+A(x)))
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
#define N 1000001
#define INF 1100000000
#define Kafuu return
#define Chino 0
#define fx(l,n) inline l n
#define set(l,n,ty,len) memset(l,n,sizeof(ty)*len)
#define cpy(f,t,ty,len) memcpy(t,f,sizeof(ty)*len)
#define int long long
#define R register int
#define C const
using namespace std;
C int mod=998244353,pr=3;
int f[N],ppr[N],br[N],expcp[N],exp[N],n,hx=1,x=1,invx,invp;
fx(int,gi)(){
char c=getchar();int s=0,f=1;
while(c>'9'||c<'0'){
if(c=='-') f=-f;
c=getchar();
}
while(c>='0'&&c<='9') s=(s<<3)+(s<<1)+(c-'0'),c=getchar();
return s*f;
}
fx(int,pow)(int a,int b=mod-2){
int ans=1;
while(b){
if(b&1) (ans*=a)%=mod;
(a*=a)%=mod;
b>>=1;
}
return ans;
}
fx(void,NTT)(int *f,C bool r,C int x){
R i,len,hl,s,expr,uni;
for(i=0;i<x;i++) br[i]=(br[i>>1]>>1)|((i&1)?x>>1:0);
for(i=0;i<x;i++) if(i<br[i]) swap(f[i],f[br[i]]);
for(len=2,hl=1;len<=x;hl=len,len<<=1){
expr=pow(r?pr:invp,(mod-1)/len);
for(i=1;i<hl;i++) ppr[i]=ppr[i-1]*expr%mod;
for(s=0;s<x;s+=len){
for(i=0;i<hl;i++){
uni=ppr[i]*f[i|s|hl]%mod;
f[i|s|hl]=(f[i|s]-uni+mod)%mod;
f[i|s]=(f[i|s]+uni)%mod;
}
}
}
if(!r){
invx=pow(x);
for(i=0;i<x;i++) (f[i]*=invx)%=mod;
}
}
int le[N],ri[N],inv[N];
fx(void,INV)(int *f,C int x){
inv[0]=pow(f[0]);
for(R dl=4,len=2,hl=1,o;len<=x;hl=len,len=dl,dl<<=1){
set(le+hl,0,int,len+hl);set(ri+len,0,int,len);
for(o=0;o<hl;o++) le[o]=(inv[o]<<1)%mod;
cpy(f,ri,int,len);
NTT(inv,1,dl),NTT(ri,1,dl);
for(o=0;o<dl;o++) (inv[o]*=inv[o]*ri[o]%mod)%=mod;
NTT(inv,0,dl);
for(o=0;o<len;o++) inv[o]=(le[o]-inv[o]+mod)%mod;
set(inv+len,0,int,len);
}
cpy(inv,f,int,x);set(inv,0,int,x);
}
fx(void,DER)(int *f,C int x){
for(R i=1;i<x;i++) f[i-1]=f[i]*i%mod;
f[x-1]=0;
}
fx(void,INT)(int *f,C int x){
for(R i=x-1;i>=1;i--) f[i]=f[i-1]*pow(i)%mod;
f[0]=0;
}
int lncp[N];
fx(void,LN)(int *f,C int hx,C int x){
cpy(f,lncp,int,hx);set(lncp+hx,0,int,hx);
INV(f,hx);DER(lncp,hx);
NTT(f,1,x);NTT(lncp,1,x);
for(R i=0;i<x;i++) (f[i]*=lncp[i])%=mod;
NTT(f,0,x);INT(f,hx);
set(f+hx,0,int,hx);
}
fx(void,EXP)(int *f,C int x){
exp[0]=1;
for(R hl=1,len=2,dl=4,i;len<=x;hl=len,len=dl,dl<<=1){
set(expcp,0,expcp,1);
cpy(exp,expcp,int,hl);
LN(expcp,len,dl);
for(i=0;i<len;i++) expcp[i]=(f[i]-expcp[i]+mod)%mod;
(expcp[0]+=1)%mod;
NTT(exp,1,dl);NTT(expcp,1,dl);
for(i=0;i<dl;i++) (exp[i]*=expcp[i])%=mod;
NTT(exp,0,dl);
set(exp+len,0,int,len);
}
cpy(exp,f,int,n);
}
signed main(){
ppr[0]=1;invp=pow(pr);
n=gi();hx=1,x=1;
set(f,0,f,1);
for(R i=0;i<n;i++) f[i]=gi();
while(x<n) x<<=1;
EXP(f,x);
for(R i=0;i<n;i++) printf("%lld ",f[i]);
printf("
");
}
多项式快速幂
(A^k(x)equiv B(x)pmod{x^n})
两边取对数:(ln A^k(x)equivln B(x)pmod{x^n})
(kln A(x)equivln B(x)pmod{x^n})
( ext e^{kln A(x)}equiv B(x)pmod{x^n})
先取对数,之后 (exp) 即可
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
#define N 1000001
#define INF 1100000000
#define Kafuu return
#define Chino 0
#define fx(l,n) inline l n
#define set(l,n,ty,len) memset(l,n,sizeof(ty)*len)
#define cpy(f,t,ty,len) memcpy(t,f,sizeof(ty)*len)
#define int long long
#define R register int
#define C const
using namespace std;
C int mod=998244353,pr=3;
int f[N],ppr[N],br[N],expcp[N],exp[N],n,x=1,invx,invp,k;
fx(int,gi)(){
char c=getchar();int s=0,f=1;
while(c>'9'||c<'0'){
if(c=='-') f=-f;
c=getchar();
}
while(c>='0'&&c<='9') s=((s<<3)+(s<<1)+(c-'0'))%mod,c=getchar();
return s*f%mod;
}
fx(int,pow)(int a,int b=mod-2){
int ans=1;
while(b){
if(b&1) (ans*=a)%=mod;
(a*=a)%=mod;
b>>=1;
}
return ans;
}
fx(void,NTT)(int *f,C bool r,C int x){
R i,len,hl,s,expr,uni;
for(i=0;i<x;i++) br[i]=(br[i>>1]>>1)|((i&1)?x>>1:0);
for(i=0;i<x;i++) if(i<br[i]) swap(f[i],f[br[i]]);
for(len=2,hl=1;len<=x;hl=len,len<<=1){
expr=pow(r?pr:invp,(mod-1)/len);
for(i=1;i<hl;i++) ppr[i]=ppr[i-1]*expr%mod;
for(s=0;s<x;s+=len){
for(i=0;i<hl;i++){
uni=ppr[i]*f[i|s|hl]%mod;
f[i|s|hl]=(f[i|s]-uni+mod)%mod;
f[i|s]=(f[i|s]+uni)%mod;
}
}
}
if(!r){
invx=pow(x);
for(i=0;i<x;i++) (f[i]*=invx)%=mod;
}
}
int le[N],ri[N],inv[N];
fx(void,INV)(int *f,C int x){
inv[0]=pow(f[0]);
for(R dl=4,len=2,hl=1,o;len<=x;hl=len,len=dl,dl<<=1){
set(le+hl,0,int,len+hl);set(ri+len,0,int,len);
for(o=0;o<hl;o++) le[o]=(inv[o]<<1)%mod;
cpy(f,ri,int,len);
NTT(inv,1,dl),NTT(ri,1,dl);
for(o=0;o<dl;o++) (inv[o]*=inv[o]*ri[o]%mod)%=mod;
NTT(inv,0,dl);
for(o=0;o<len;o++) inv[o]=(le[o]-inv[o]+mod)%mod;
set(inv+len,0,int,len);
}
cpy(inv,f,int,x);set(inv,0,int,x);
}
fx(void,DER)(int *f,C int x){
for(R i=1;i<x;i++) f[i-1]=f[i]*i%mod;
f[x-1]=0;
}
fx(void,INT)(int *f,C int x){
for(R i=x-1;i>=1;i--) f[i]=f[i-1]*pow(i)%mod;
f[0]=0;
}
int lncp[N];
fx(void,LN)(int *f,C int hx,C int x){
cpy(f,lncp,int,hx);set(lncp+hx,0,int,hx);
INV(f,hx);DER(lncp,hx);
NTT(f,1,x);NTT(lncp,1,x);
for(R i=0;i<x;i++) (f[i]*=lncp[i])%=mod;
NTT(f,0,x);INT(f,hx);
set(f+hx,0,int,hx);
}
fx(void,EXP)(int *f,C int x){
exp[0]=1;
for(R hl=1,len=2,dl=4,i;len<=x;hl=len,len=dl,dl<<=1){
set(expcp,0,expcp,1);
cpy(exp,expcp,int,hl);
LN(expcp,len,dl);
for(i=0;i<len;i++) expcp[i]=(f[i]-expcp[i]+mod)%mod;
(expcp[0]+=1)%mod;
NTT(exp,1,dl);NTT(expcp,1,dl);
for(i=0;i<dl;i++) (exp[i]*=expcp[i])%=mod;
NTT(exp,0,dl);
set(exp+len,0,int,len);
}
cpy(exp,f,int,n);
}
signed main(){
ppr[0]=1;invp=pow(pr);
n=gi();k=gi();
set(f,0,f,1);
for(R i=0;i<n;i++) f[i]=gi();
while(x<n) x<<=1;
LN(f,x,x<<1);
for(R i=0;i<n;i++) (f[i]*=k)%=mod;
EXP(f,x);
for(R i=0;i<n;i++) printf("%lld ",f[i]);
printf("
");
}
多项式开根
尝试使用牛顿迭代,(F(G_0(x))=B_0^2(x)-A(x))
直接无脑推式子就可以
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
#define N 1000001
#define M 5001
#define INF 1100000000
#define Kafuu return
#define Chino 0
#define fx(l,n) inline l n
#define set(l,n,ty,len) memset(l,n,sizeof(ty)*len)
#define cpy(f,t,ty,len) memcpy(t,f,sizeof(ty)*len)
#define R register
#define C const
#define int long long
using namespace std;
const int mod=998244353,pr=3;
int br[N],x=1,n,f[N],invp,invx,invt,ppr[N];
fx(int,gi)(){
R char c=getchar();R int s=0,f=1;
while(c>'9'||c<'0'){
if(c=='-') f=-f;
c=getchar();
}
while(c<='9'&&c>='0') s=(s<<3)+(s<<1)+(c-'0'),c=getchar();
return s*f;
}
fx(int,pow)(int a,int b=mod-2){
int ans=1;
while(b){
if(b&1) (ans*=a)%=mod;
(a*=a)%=mod;
b>>=1;
}
return ans;
}
fx(void,NTT)(int *f,C bool r,C int x){
R int i,len,hl,s,uni,expr;
for(i=0;i<x;i++) br[i]=(br[i>>1]>>1)|((i&1)?x>>1:0);
for(i=0;i<x;i++) if(i<br[i]) swap(f[i],f[br[i]]);
for(hl=1,len=2;len<=x;hl=len,len<<=1){
expr=pow(r?pr:invp,(mod-1)/len);
for(i=1;i<hl;i++) ppr[i]=ppr[i-1]*expr%mod;
for(s=0;s<x;s+=len){
for(i=0;i<hl;i++){
uni=ppr[i]*f[i|s|hl]%mod;
f[i|s|hl]=(f[i|s]-uni+mod)%mod;
f[i|s]=(f[i|s]+uni)%mod;
}
}
}
if(!r){
invx=pow(x);
for(i=0;i<x;i++) (f[i]*=invx)%=mod;
}
}
int le[N],ri[N],inv[N];
fx(void,INV)(int *f,C int x){
inv[0]=pow(f[0]);
for(R int dl=4,len=2,hl=1,o;len<=x;hl=len,len=dl,dl<<=1){
for(o=0;o<hl;o++) le[o]=(inv[o]<<1)%mod;
cpy(f,ri,int,len);
NTT(inv,1,dl),NTT(ri,1,dl);
for(o=0;o<dl;o++) (inv[o]*=inv[o]*ri[o]%mod)%=mod;
NTT(inv,0,dl);
for(o=0;o<len;o++) inv[o]=(le[o]-inv[o]+mod)%mod;
set(inv+len,0,int,len);
}
cpy(inv,f,int,x);set(inv,0,int,x);
set(le,0,int,x>>1);set(ri,0,int,x<<1);
}
int sqrt[N],scp[N],sac[N];
fx(void,SQRT)(int *f,int x){
sqrt[0]=1;
for(R int hl=1,len=2,dl=4,i;len<=x;hl=len,len=dl,dl<<=1){
cpy(sqrt,scp,int,len);cpy(f,sac,int,len);
INV(scp,len);
NTT(sac,1,dl);NTT(scp,1,dl);
for(i=0;i<dl;i++) (sac[i]*=scp[i])%=mod;
NTT(sac,0,dl);
for(i=0;i<len;i++) sqrt[i]=(sqrt[i]+sac[i])*invt%mod;
}
cpy(sqrt,f,int,x);
}
signed main(){
ppr[0]=1;invp=pow(pr);invt=pow(2);
n=gi();
for(R int i=0;i<n;i++) f[i]=gi();
while(x<n) x<<=1;
SQRT(f,x);
for(R int i=0;i<n;i++) printf("%lld ",f[i]);
}
多项式三角函数
根据欧拉公式:
类似NTT使用原根表示单位根,在模意义下 (i=g^frac{p-1}4)
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
#define N 1000001
#define INF 1100000000
#define Kafuu return
#define Chino 0
#define fx(l,n) inline l n
#define set(l,n,ty,len) memset(l,n,sizeof(ty)*len)
#define cpy(f,t,ty,len) memcpy(t,f,sizeof(ty)*len)
#define int long long
#define R register int
#define C const
using namespace std;
C int mod=998244353,pr=3;
int f[N],g[N],ppr[N],br[N],expcp[N],exp[N],n,x=1,invx,invp,ty,I,invt,invi;
fx(int,gi)(){
char c=getchar();int s=0,f=1;
while(c>'9'||c<'0'){
if(c=='-') f=-f;
c=getchar();
}
while(c>='0'&&c<='9') s=(s<<3)+(s<<1)+(c-'0'),c=getchar();
return s*f;
}
fx(int,pow)(int a,int b=mod-2){
int ans=1;
while(b){
if(b&1) (ans*=a)%=mod;
(a*=a)%=mod;
b>>=1;
}
return ans;
}
fx(void,NTT)(int *f,C bool r,C int x){
R i,len,hl,s,expr,uni;
for(i=0;i<x;i++) br[i]=(br[i>>1]>>1)|((i&1)?x>>1:0);
for(i=0;i<x;i++) if(i<br[i]) swap(f[i],f[br[i]]);
for(len=2,hl=1;len<=x;hl=len,len<<=1){
expr=pow(r?pr:invp,(mod-1)/len);
for(i=1;i<hl;i++) ppr[i]=ppr[i-1]*expr%mod;
for(s=0;s<x;s+=len){
for(i=0;i<hl;i++){
uni=ppr[i]*f[i|s|hl]%mod;
f[i|s|hl]=(f[i|s]-uni+mod)%mod;
f[i|s]=(f[i|s]+uni)%mod;
}
}
}
if(!r){
invx=pow(x);
for(i=0;i<x;i++) (f[i]*=invx)%=mod;
}
}
int le[N],ri[N],inv[N];
fx(void,INV)(int *f,C int x){
inv[0]=pow(f[0]);
for(R dl=4,len=2,hl=1,o;len<=x;hl=len,len=dl,dl<<=1){
for(o=0;o<hl;o++) le[o]=(inv[o]<<1)%mod;
cpy(f,ri,int,len);
NTT(inv,1,dl),NTT(ri,1,dl);
for(o=0;o<dl;o++) (inv[o]*=inv[o]*ri[o]%mod)%=mod;
NTT(inv,0,dl);
for(o=0;o<len;o++) inv[o]=(le[o]-inv[o]+mod)%mod;
set(inv+len,0,int,len);
}
cpy(inv,f,int,x);set(inv,0,int,x);
set(le,0,int,x>>1);set(ri,0,int,x<<1);
}
fx(void,DER)(int *f,C int x){
for(R i=1;i<x;i++) f[i-1]=f[i]*i%mod;
f[x-1]=0;
}
fx(void,INT)(int *f,C int x){
for(R i=x-1;i>=1;i--) f[i]=f[i-1]*pow(i)%mod;
f[0]=0;
}
int lncp[N];
fx(void,LN)(int *f,C int hx,C int x){
cpy(f,lncp,int,hx);set(lncp+hx,0,int,hx);
INV(f,hx);DER(lncp,hx);
NTT(f,1,x);NTT(lncp,1,x);
for(R i=0;i<x;i++) (f[i]*=lncp[i])%=mod;
NTT(f,0,x);INT(f,hx);
set(f+hx,0,int,hx);
}
fx(void,EXP)(int *f,C int x){
set(expcp,0,expcp,1);set(exp,0,exp,1);exp[0]=1;
for(R hl=1,len=2,dl=4,i;len<=x;hl=len,len=dl,dl<<=1){
cpy(exp,expcp,int,hl);set(expcp+hl,0,int,hl);
LN(expcp,len,dl);
for(i=0;i<len;i++) expcp[i]=(f[i]-expcp[i]+mod)%mod;
(expcp[0]+=1)%mod;
NTT(exp,1,dl);NTT(expcp,1,dl);
for(i=0;i<dl;i++) (exp[i]*=expcp[i])%=mod;
NTT(exp,0,dl);
set(exp+len,0,int,len);
}
cpy(exp,f,int,x);
}
fx(void,SIN)(int *f,int *g,C int x){
EXP(f,x);EXP(g,x);
for(R i=0;i<x;i++) f[i]=(f[i]-g[i]+mod)*invt%mod*invi%mod;
}
fx(void,COS)(int *f,int *g,C int x){
EXP(f,x);EXP(g,x);
for(R i=0;i<x;i++) f[i]=(f[i]+g[i])*invt%mod;
}
signed main(){
ppr[0]=1;invp=pow(pr);I=pow(pr,(mod-1)/4);invt=pow(2);invi=pow(I);
n=gi();ty=gi();
for(R i=0;i<n;i++) f[i]=gi(),(f[i]*=I)%=mod,g[i]=mod-f[i];
while(x<n) x<<=1;
if(ty) COS(f,g,x);
else SIN(f,g,x);
for(R i=0;i<n;i++) printf("%lld ",f[i]);
printf("
");
}
多项式反三角函数
对于 (arcsin x) 与 (arctan x),我们可以将其求导,变成初等函数能表示的形式,然后积分
虽然字面上看这样没什么意义,求导再积分嘛,就和+1再-1一样,但是求导跟加减可不一样
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
#define N 1000001
#define M 5001
#define INF 1100000000
#define Kafuu return
#define Chino 0
#define fx(l,n) inline l n
#define set(l,n,ty,len) memset(l,n,sizeof(ty)*len)
#define cpy(f,t,ty,len) memcpy(t,f,sizeof(ty)*len)
#define R register
#define C const
#define int long long
using namespace std;
const int mod=998244353,pr=3;
int br[N],x=1,n,f[N],invp,invx,invt,ppr[N],g[N],ty;
fx(int,gi)(){
R char c=getchar();R int s=0,f=1;
while(c>'9'||c<'0'){
if(c=='-') f=-f;
c=getchar();
}
while(c<='9'&&c>='0') s=(s<<3)+(s<<1)+(c-'0'),c=getchar();
return s*f;
}
fx(int,pow)(int a,int b=mod-2){
int ans=1;
while(b){
if(b&1) (ans*=a)%=mod;
(a*=a)%=mod;
b>>=1;
}
return ans;
}
fx(void,NTT)(int *f,C bool r,C int x){
R int i,len,hl,s,uni,expr;
for(i=0;i<x;i++) br[i]=(br[i>>1]>>1)|((i&1)?x>>1:0);
for(i=0;i<x;i++) if(i<br[i]) swap(f[i],f[br[i]]);
for(hl=1,len=2;len<=x;hl=len,len<<=1){
expr=pow(r?pr:invp,(mod-1)/len);
for(i=1;i<hl;i++) ppr[i]=ppr[i-1]*expr%mod;
for(s=0;s<x;s+=len){
for(i=0;i<hl;i++){
uni=ppr[i]*f[i|s|hl]%mod;
f[i|s|hl]=(f[i|s]-uni+mod)%mod;
f[i|s]=(f[i|s]+uni)%mod;
}
}
}
if(!r){
invx=pow(x);
for(i=0;i<x;i++) (f[i]*=invx)%=mod;
}
}
int le[N],ri[N],inv[N];
fx(void,INV)(int *f,C int x){
inv[0]=pow(f[0]);
for(R int dl=4,len=2,hl=1,o;len<=x;hl=len,len=dl,dl<<=1){
for(o=0;o<hl;o++) le[o]=(inv[o]<<1)%mod;
cpy(f,ri,int,len);
NTT(inv,1,dl),NTT(ri,1,dl);
for(o=0;o<dl;o++) (inv[o]*=inv[o]*ri[o]%mod)%=mod;
NTT(inv,0,dl);
for(o=0;o<len;o++) inv[o]=(le[o]-inv[o]+mod)%mod;
set(inv+len,0,int,len);
}
cpy(inv,f,int,x);set(inv,0,int,x);
set(le,0,int,x>>1);set(ri,0,int,x<<1);
}
fx(void,DER)(int *f,C int x){
for(R int i=1;i<x;i++) f[i-1]=f[i]*i%mod;
f[x-1]=0;
}
fx(void,INT)(int *f,C int x){
for(R int i=x-1;i>=1;i--) f[i]=f[i-1]*pow(i)%mod;
f[0]=0;
}
int sqrt[N],scp[N],sac[N];
fx(void,SQRT)(int *f,C int x){
sqrt[0]=1;
for(R int len=2,dl=4,i;len<=x;len=dl,dl<<=1){
cpy(sqrt,scp,int,len);cpy(f,sac,int,len);
INV(scp,len);
NTT(sac,1,dl);NTT(scp,1,dl);
for(i=0;i<dl;i++) (sac[i]*=scp[i])%=mod;
NTT(sac,0,dl);
for(i=0;i<len;i++) sqrt[i]=(sqrt[i]+sac[i])*invt%mod;
}
cpy(sqrt,f,int,x);
}
fx(void,ARCSIN)(int *f,C int x){
cpy(f,g,int,x);
DER(g,x);NTT(f,1,x<<1);NTT(g,1,x<<1);
for(R int i=0;i<(x<<1);i++) (f[i]*=f[i])%=mod;
NTT(f,0,x<<1);
set(f+x,0,int,x);
for(R int i=0;i<x;i++) f[i]=mod-f[i];
(f[0]+=1)%=mod;
SQRT(f,x);INV(f,x);NTT(f,1,x<<1);
for(R int i=0;i<(x<<1);i++) (f[i]*=g[i])%=mod;
NTT(f,0,x<<1);INT(f,x);
}
fx(void,ARCTAN)(int *f,C int x){
cpy(f,g,int,x);
DER(g,x);NTT(f,1,x<<1);NTT(g,1,x<<1);
for(R int i=0;i<(x<<1);i++) (f[i]*=f[i])%=mod;
NTT(f,0,x<<1);set(f+x,0,int,x);
(f[0]+=1)%=mod;
INV(f,x);NTT(f,1,x<<1);
for(R int i=0;i<(x<<1);i++) (f[i]*=g[i])%=mod;
NTT(f,0,x<<1);INT(f,x);
}
signed main(){
ppr[0]=1;invp=pow(pr);invt=pow(2);
n=gi();ty=gi();
for(R int i=0;i<n;i++) f[i]=gi();
while(x<n) x<<=1;
if(ty) ARCTAN(f,x);
else ARCSIN(f,x);
for(R int i=0;i<n;i++) printf("%lld ",f[i]);
}