此处不写完整题解,只记录实现上的细节。
记住:用 NTT 计算两个多项式的乘积时,必须在高位留出足够的空位并清零(变换长度不能小于乘积的项数),否则循环卷积会把高次项回绕到低次项上,
也就是代码里的 memset(b+n,0,n<<2);
然后,在做完模 x^n 的运算后,一定要把次数 ≥ n 的部分清零,使结果保持在模 x^n 的意义下(当然,如果能记得,等到之后使用前再清零也行)。
另外,从本题还学到了两个卡常(常数优化)技巧:
一、每层先预处理单位根表,而不是在蝴蝶运算的内层循环里累乘单位根:
// Trick 1: how each stage's twiddle factors are produced. The two functions below are
// alternative versions of the same ntt(). They rely on names declared outside this snippet
// (rev, w, G, mo, po, LL, swap).

// Variant 1 (measured 24732 ms): the twiddles of each stage are precomputed once into the
// table w[], costing k-1 modular multiplications per stage. w[0] must already hold 1.
void ntt(int a[],int n,int d=1){
    for (int i=0;i<n;++i) if (i<rev[i]) swap(a[i],a[rev[i]]);
    for (int m=2;m<=n;m<<=1){
        int k=m>>1; LL wm=po(G,d*(mo-1)/m);

        for (int i=1;i<k;++i) w[i]=w[i-1]*wm%mo; // Is this the bottleneck?

        for (int i=0;i<n;i+=m){
            for (int j=i;j<i+k;++j){
                int u=a[j],v=1LL*w[j-i]*a[j+k]%mo;
                a[j]=u+v<mo?u+v:u+v-mo;
                a[j+k]=u<v?mo-v+u:u-v;
            }
        }
    }
    if (d==-1){
        // Inverse transform: scale every entry by n^(mo-2), i.e. n^{-1} mod mo.
        LL x=po(n,mo-2);
        for (int i=0;i<n;++i) a[i]=x*a[i]%mo;
    }
}

// Variant 2 (measured 31844 ms): there is no table. A running twiddle w (a local scalar) is
// reset to 1 for every block and advanced with one extra modular multiplication per
// butterfly. That is about n/2 extra multiplications per stage instead of k-1, which is
// presumably why this variant is slower.
void ntt(int a[],int n,int d=1){
    for (int i=0;i<n;++i) if (i<rev[i]) swap(a[i],a[rev[i]]);
    for (int m=2;m<=n;m<<=1){
        int k=m>>1,wm=po(G,d*(mo-1)/m);
        for (int i=0;i<n;i+=m){
            LL w=1;
            for (int j=i;j<i+k;++j){
                int u=a[j],v=w*a[j+k]%mo;
                a[j]=u+v<mo?u+v:u+v-mo;
                a[j+k]=u<v?mo-v+u:u-v;
                w=w*wm%mo;
            }
        }
    }
    if (d==-1){
        // Inverse transform: scale every entry by n^(mo-2), i.e. n^{-1} mod mo.
        LL x=po(n,mo-2);
        for (int i=0;i<n;++i) a[i]=x*a[i]%mo;
    }
}
二、单位根表用 int 存储,而不是 long long:
// Trick 2: element type of the per-stage twiddle table. These are three alternative
// versions of the table-building fragment from inside ntt's stage loop (d, m and k come
// from that loop). They differ only in using int or long long for w[] and wm.

// Variant 1 (measured 21844 ms): int table and int wm; the product is widened via 1LL.
int w[N];
int wm=po(G,d*(mo-1)/m);
for (int i=1;i<k;++i) w[i]=1LL*w[i-1]*wm%mo;

// Variant 2 (measured 24732 ms): long long table and long long wm.
// Author's conclusion: the bottleneck is probably reading and writing long long values.
// This is unverified; note that a long long table typically occupies twice the memory
// of an int table.
LL w[N];
LL wm=po(G,d*(mo-1)/m);
for (int i=1;i<k;++i) w[i]=w[i-1]*wm%mo;

// Variant 3 (measured 24696 ms): long long table but int wm. Its timing is essentially
// equal to variant 2, which the author uses to show that LL*LL vs. LL*int multiplication
// is not the bottleneck.
LL w[N];
int wm=po(G,d*(mo-1)/m);
for (int i=1;i<k;++i) w[i]=w[i-1]*wm%mo;
完整代码:
// For s = 1..m, prints [x^s] of 2 / (1 + sqrt(1 - 4*C(x))) modulo 998244353, where
// C(x) = sum of x^w over the input weights w (weights > m cannot contribute and are skipped).
// Since 2 / (1 + sqrt(1 - 4C)) == (1 - sqrt(1 - 4C)) / (2C), this is the power series F
// with F(0) = 1 and F = 1 + C * F^2. This is presumably the generating function of binary
// trees whose node weights come from the input set, indexed by total weight; confirm
// against the problem statement.
//
// Input:  cnt m, followed by cnt weights.
// Output: m space-separated coefficients.
#include <cstdio>
#include <cstring>
#include <utility>

using LL = long long;

constexpr int mo = 998244353;  // NTT-friendly prime: mo - 1 = 119 * 2^23
constexpr int G = 3;           // primitive root modulo mo
constexpr int N = 300000;      // capacity of every work array; main checks 2 * (NTT length) <= N

// The std::memset(p, 0, len << 2) calls below clear `len` ints; the shift hard-codes a 4-byte int.
static_assert(sizeof(int) == 4, "memset byte counts (len << 2) assume a 4-byte int");

// rev: bit-reversal table for the current NTT length; w: twiddle table for one NTT stage;
// b, c: the series main works on; t1, t2: scratch for gen(); t3: scratch for inv().
int rev[N], b[N], c[N], t1[N], t2[N], t3[N], w[N];

// Returns x^y mod mo by binary exponentiation. A negative y is reduced modulo mo - 1
// (Fermat), which is only meaningful when x is not a multiple of mo.
int po(int x, int y) {
    if (y < 0) y = y % (mo - 1) + mo - 1;
    int z = 1;
    for (; y; y >>= 1, x = 1LL * x * x % mo)
        if (y & 1) z = 1LL * z * x % mo;
    return z;
}

// Fills rev[0..len-1] with the bit-reversal permutation for a power-of-two length len.
static void build_rev(int len) {
    rev[0] = 0;
    for (int i = 1; i < len; ++i) rev[i] = (rev[i >> 1] >> 1) + (i & 1) * (len >> 1);
}

// In-place NTT of a[0..n-1]. n must be a power of two, and rev[] must already hold the
// bit-reversal table for n. d = 1 is the forward transform; d = -1 is the inverse
// (including the 1/n scaling).
void ntt(int a[], int n, int d = 1) {
    for (int i = 0; i < n; ++i)
        if (i < rev[i]) std::swap(a[i], a[rev[i]]);
    w[0] = 1;  // every stage's table starts from the 0th power of its root
    for (int m = 2; m <= n; m <<= 1) {
        const int k = m >> 1;
        const int wm = po(G, d * (mo - 1) / m);  // primitive m-th root of unity (its inverse if d == -1)
        // One int twiddle table per stage. The author timed this faster than advancing a running
        // twiddle inside the butterfly loop, and faster than storing the table as long long.
        for (int i = 1; i < k; ++i) w[i] = 1LL * w[i - 1] * wm % mo;
        for (int i = 0; i < n; i += m)
            for (int j = i; j < i + k; ++j) {
                const int u = a[j];
                const int v = 1LL * w[j - i] * a[j + k] % mo;
                a[j] = u + v < mo ? u + v : u + v - mo;
                a[j + k] = u < v ? mo - v + u : u - v;
            }
    }
    if (d == -1) {
        const LL inv_n = po(n, mo - 2);
        for (int i = 0; i < n; ++i) a[i] = inv_n * a[i] % mo;
    }
}

// Power-series inverse: writes b = a^{-1} mod x^n (n a power of two, a[0] != 0).
// Uses the Newton step b <- b * (2 - a*b), which doubles the precision at each level.
// On return b[n..2n-1] are zero. Clobbers rev[] and t3[]; b must have room for 2n ints.
void inv(int* a, int n, int* b) {
    if (n == 1) {
        b[0] = po(a[0], mo - 2);
        b[1] = 0;  // so the caller's high half starts out clean
        return;
    }
    inv(a, n >> 1, b);  // b = a^{-1} mod x^{n/2}, and b[n/2..n-1] == 0
    const int m = n << 1;  // a*b*b has degree <= 2n-3, so length 2n cannot wrap around
    for (int i = 0; i < n; ++i) t3[i] = a[i];
    std::memset(b + n, 0, n << 2);  // zero the high halves before the cyclic convolution
    std::memset(t3 + n, 0, n << 2);
    build_rev(m);
    ntt(t3, m);
    ntt(b, m);
    for (int i = 0; i < m; ++i) b[i] = (2LL + mo - 1LL * t3[i] * b[i] % mo) * b[i] % mo;
    ntt(b, m, -1);
    std::memset(b + n, 0, n << 2);  // truncate the result to mod x^n
}

// Power-series square root: writes b = sqrt(a) mod x^n (n a power of two).
// Assumes a[0] == 1 (the base case returns b[0] = 1). Uses the Newton step b <- (b + a/b) / 2.
// On return b[n..2n-1] are zero. Clobbers rev[], t1[], t2[] and t3[].
void gen(int* a, int n, int* b) {
    if (n == 1) {
        b[0] = 1;
        b[1] = 0;
        return;
    }
    gen(a, n >> 1, b);  // b = sqrt(a) mod x^{n/2}, and b[n/2..n-1] == 0
    const int m = n << 1;
    for (int i = 0; i < n; ++i) t1[i] = a[i];
    std::memset(t1 + n, 0, n << 2);
    inv(b, n, t2);  // t2 = b^{-1} mod x^n (uses rev[] and t3[] internally)
    build_rev(m);
    ntt(t1, m);
    ntt(t2, m);
    for (int i = 0; i < m; ++i) t1[i] = 1LL * t1[i] * t2[i] % mo;
    ntt(t1, m, -1);  // t1 = a / b
    const LL inv2 = (mo + 1) >> 1;  // modular inverse of 2
    for (int i = 0; i < n; ++i) b[i] = inv2 * (b[i] + t1[i]) % mo;
    std::memset(b + n, 0, n << 2);  // truncate the result to mod x^n
}

int main() {
    int cnt, m;
    if (std::scanf("%d%d", &cnt, &m) != 2) return 0;

    // NTT length: the smallest power of two > m, so that coefficients 0..m are kept.
    // The loop bound stops len from overflowing; the check below keeps 2*len entries in bounds.
    int len = 1;
    while (len <= m && 2 * len <= N) len <<= 1;
    if (2 * len > N) {
        std::fprintf(stderr, "m = %d exceeds the buffer capacity\n", m);
        return 1;
    }

    for (int i = 0; i < cnt; ++i) {
        int x;
        if (std::scanf("%d", &x) != 1) return 0;
        if (x > 0 && x <= m) c[x] = mo - 4;  // c = 1 - 4C (c[0] is set below)
    }
    c[0] = 1;

    gen(c, len, b);           // b = sqrt(1 - 4C); b[0] is always 1, since (1 + 1/1)/2 == 1 at every step
    b[0] = (b[0] + 1) % mo;   // b = 1 + sqrt(1 - 4C)
    inv(b, len, c);           // c = 1 / (1 + sqrt(1 - 4C))
    for (int i = 1; i <= m; ++i) std::printf("%d ", (c[i] << 1) % mo);
    return 0;
}