多项式求逆是什么
对于一个(n)次多项式(F(x)),要求一个小于等于(n)次的多项式(G(x)),满足
(F(x)G(x)equiv1(mod x^n))
(mod x^n)即只考虑所有多项式的前n项。
怎么做多项式求逆
显然,当(F(x))次数为0,即只有常数项时,它的逆元就是常数项的逆元。
对于次数大于0的多项式我们假设我们已经递归求出(F(x))在(mod x^{lceilfrac{n}{2} ceil})意义下的逆(H(x))。
也就是我们有
(F(x)H(x)equiv1(mod x^{lceilfrac{n}{2} ceil}))
由(F(x)G(x)equiv1(mod x^n))易知(F(x)G(x)equiv1(mod x^{lceilfrac{n}{2} ceil}))。
两式相减得
(F(x)[H(x)-G(x)]equiv0(mod x^{lceilfrac{n}{2} ceil}))
我们有(F(x) ot=0),即(H(x)-G(x)equiv0(mod x^{lceilfrac{n}{2} ceil}))
由于我们有若(aequiv b(mod p)),则(a^2equiv b^2(mod x^2))
则两边平方得
(H(x)^2-2G(x)H(x)+G(x)^2equiv0(mod x^{2 imeslceilfrac{n}{2} ceil}))
.因为(2 imeslceilfrac{x}{2} ceilge n)所以(H(x)^2-2G(x)H(x)+G(x)^2equiv0(mod x^n))
两边同乘(F(x)),由于(F(x)G(x)equiv1(mod x^n))
(F(x)H(x)^2-2H(x)+G(x)equiv0(mod x^n))
(G(x)equiv 2H(x)-F(x)H(x)^2(mod x^n))
(G(x)equiv H(x)[2-F(x)H(x)](mod x^n))
我们可以NTT实现多项式乘法,时间复杂度(O(nlog^2n))。
code:
#include<bits/stdc++.h>
#define ci const int&
#define VAL(p,n,i) (i<n?p[i]:0)
using namespace std;
const int mod=998244353;
const int g=3;
int cpy[600010];
int POW(int x,int y){
int tot=1;
while(y)y&1?tot=1ll*tot*x%mod:0,x=1ll*x*x%mod,y>>=1;
return tot;
}
void NTT(vector<int>&f,ci l,ci len,ci op){
if(len&1)return;
for(int i=l;i<l+len;++i)cpy[i]=f[i];
int nw=l-1,ln=len>>1;
for(int i=l;i<l+ln;++i)f[i]=cpy[++nw],f[i+ln]=cpy[++nw];
NTT(f,l,ln,op),NTT(f,l+ln,ln,op);
int rt=POW(g,(mod-1)/len),t;
op?rt=POW(rt,mod-2):0;
nw=1;
for(int i=l;i<l+len;++i)cpy[i]=f[i];
for(int i=l;i<l+ln;++i,nw=1ll*nw*rt%mod)t=1ll*nw*cpy[i+ln]%mod,f[i]=(cpy[i]+t)%mod,f[i+ln]=(cpy[i]-t+mod)%mod;
}
vector<int>F;
vector<int>T;
vector<int>tmp;
vector<int>a;
vector<int>b;
vector<int>c;
int ts,sz,tg,inv;
void print(const vector<int>&x){
for(int i=0;i<x.size();++i)printf("%d ",x[i]);
}
vector<int>calc(const vector<int>&x,const vector<int>&y){//2y-x*y^2
ts=x.size()+y.size()+y.size()-2,sz=1,a.clear(),b.clear(),c.clear();
while(sz<ts)sz<<=1;
for(int i=0;i<x.size();++i)a.push_back(x[i]);
for(int i=0;i<y.size();++i)b.push_back(y[i]);
a.resize(sz),b.resize(sz),NTT(a,0,sz,0),NTT(b,0,sz,0);
for(int i=0;i<sz;++i)c.push_back((2-1ll*a[i]*b[i]%mod+mod)%mod*b[i]%mod);
NTT(c,0,sz,1),inv=POW(sz,mod-2);
for(int i=0;i<sz;++i)c[i]=1ll*c[i]*inv%mod;
return c;
}
int n,v;
vector<int>INV(const vector<int>&x){
if(x.size()==1)return T.resize(1),T[0]=POW(x[0],mod-2),T;
vector<int>G=x;
G.resize((x.size()+1)>>1),G=calc(x,INV(G)),G.resize(x.size());
return G;
}
int main(){
scanf("%d",&n);
for(int i=0;i<n;++i)scanf("%d",&v),F.push_back(v);
print(INV(F));
return 0;
}