思路
多项式除法板子
多项式除法
给出(A(x))和(B(x)),求一个(n-m)次的多项式(D(x)),一个(m-1)次多项式(R(x)),满足
[A(x)=B(x)D(x)+R(x)
]
定义(D^R(x))为多项式(D(x))系数反转的结果,可证(D^R(x)=x^nD(frac{1}{x}))
所以
[egin{align}&A(x)=B(x)D(x)+R(x)\&A(frac{1}{x})=B(frac{1}{x})D(frac{1}{x})+R(frac{1}{x})\&x^nA(frac{1}{x})=x^nB(frac{1}{x})D(frac{1}{x})+x^nR(frac{1}{x})\&A^R(x)=B^R(x)D^R(x)+x^{n-m+1}R^R(x)end{align}
]
放到模(x^{n-m+1})意义下
就消去了(R(x))的影响,然后上求逆就行了
注意反转D系数时候只反转0~n-m项系数
代码
#include <cstdio>
#include <cstring>
#include <algorithm>
#define int long long
using namespace std;
const int MAXN = 300000;
const int G = 3;
const int invG = 332748118;
const int MOD = 998244353;
int n,m;
struct Poly{
int t;//次数界
int data[MAXN];
Poly(){}
Poly(int x,int val[]){
for(int i=0;i<=x;i++)
data[i]=val[i];
}
};
int pow(int a,int b){
int ans=1;
while(b){
if(b&1)
ans=(1LL*ans*a)%MOD;
a=(1LL*a*a)%MOD;
b>>=1;
}
return ans;
}
void rever(Poly &a){
for(int i=0,j=a.t;i<j;i++,j--){
swap(a.data[i],a.data[j]);
}
}
void save(Poly &a,int top){
for(int i=top+1;i<=a.t;i++)
a.data[i]=0;
a.t=top;
}
void output(Poly a){
putchar('
');
printf("a.times=%lld
",a.t);
putchar('
');
for(int i=0;i<=a.t;i++)
printf("%lld ",a.data[i]);
putchar('
');
putchar('
');
}
void NTT(Poly &a,int opt,int n){//1 DFT 0 IDFT
int lim=0;
while((1<<(lim))<n)
lim++;
n=(1<<lim);
for(int i=0;i<n;i++){
int t=0;
for(int j=0;j<lim;j++)
if((i>>j)&1)
t|=(1<<(lim-j-1));
if(i<t)
swap(a.data[i],a.data[t]);
}
for(int i=2;i<=n;i<<=1){
int len=i/2;
int tmp=pow((opt)?G:invG,(MOD-1)/i);
for(int j=0;j<n;j+=i){
int arr=1;
for(int k=j;k<j+len;k++){
int t=(1LL*a.data[k+len]*arr)%MOD;
a.data[k+len]=(a.data[k]-t+MOD)%MOD;
a.data[k]=(a.data[k]+t)%MOD;
arr=(1LL*arr*tmp)%MOD;
}
}
}
if(!opt){
int invN = pow(n,MOD-2);
for(int i=0;i<n;i++){
a.data[i]=(a.data[i]*invN)%MOD;
}
}
}
void mul(Poly &a,Poly b){//a=a*b
int num=(a.t+b.t),lim=0;
while((1<<(lim))<=((num+2)))
lim++;
lim=(1<<lim);
NTT(a,1,lim);
NTT(b,1,lim);
for(int i=0;i<lim;i++)
a.data[i]=(1LL*a.data[i]*b.data[i])%MOD;
NTT(a,0,lim);
a.t=num;
for(int i=num+1;i<lim;i++)
a.data[i]=0;
}
void Inv(Poly a,Poly &inv,int dep,int &len){//
if(dep==1){
inv.data[0]=pow(a.data[0],MOD-2);
inv.t=dep-1;
return;
}
Inv(a,inv,(dep+1)>>1,len);
static Poly tmp;
while((dep<<1)>len)
len<<=1;
for(int i=0;i<dep;i++)
tmp.data[i]=a.data[i];
for(int i=dep;i<len;i++)
tmp.data[i]=0;
NTT(tmp,1,len);
NTT(inv,1,len);
for(int i=0;i<len;i++)
inv.data[i]=1LL*inv.data[i]*((2-1LL*inv.data[i]*tmp.data[i])%MOD+MOD)%MOD;
NTT(inv,0,len);
for(int i=dep;i<len;i++)
inv.data[i]=0;
inv.t=dep-1;
}
void div(Poly a,Poly b,Poly &D,Poly &R){
static Poly tmp1,tmp2;
int Up=a.t-b.t+1,midlen=1;
tmp1=b;
rever(tmp1);
Inv(tmp1,tmp2,Up,midlen);
tmp1=a;
rever(tmp1);
mul(tmp2,tmp1);
save(tmp2,n-m);
rever(tmp2);
D=tmp2;
mul(tmp2,b);
for(int i=0;i<b.t;i++)
R.data[i]=(a.data[i]-tmp2.data[i]+MOD)%MOD;
R.t=b.t-1;
}
Poly a,b,D,R;
signed main(){
scanf("%lld %lld",&n,&m);
for(int i=0;i<=n;i++)scanf("%lld",&a.data[i]);
a.t=n;
for(int i=0;i<=m;i++)
scanf("%lld",&b.data[i]);
b.t=m;
div(a,b,D,R);
for(int i=0;i<=D.t;i++)
printf("%lld ",D.data[i]);
putchar('
');
for(int i=0;i<=R.t;i++)
printf("%lld ",R.data[i]);
putchar('
');
return 0;
}