题目描述:
题解:
用$fft$水过(什么$ntt$我不知道)。
众所周知,$fft$精度低,$ntt$处理范围小。
所以就有了任意模数ntt神奇$fft$!
意思是这样的。比如我要算$F*G$,我可以把这两个多项式各分成两个多项式,一个表示$F_x/M$,一个表示$F_x$%$M$($M$是自己设定的阈值)。
比如说$F=a*M+b,G=c*M+d$,那么$F*G=(a*M+b)*(c*M+d)=a*c*M^2+a*d*M+b*c*M+b*d$。
然后?就水过了啊……
顺便提一下,要开$long double$。
代码:
#include<cmath> #include<cstdio> #include<cstring> #include<algorithm> using namespace std; typedef long long ll; const int N = 500050; const long double Pi = acos(-1.0); template<typename T> inline void read(T&x) { T f = 1,c = 0;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9'){c=c*10+ch-'0';ch=getchar();} x = f*c; } int n,m,MOD; struct cp { long double x,y; cp(){} cp(long double x,long double y):x(x),y(y){} cp operator + (const cp&a)const{return cp(x+a.x,y+a.y);} cp operator - (const cp&a)const{return cp(x-a.x,y-a.y);} cp operator * (const cp&a)const{return cp(x*a.x-y*a.y,x*a.y+y*a.x);} }; int to[N],lim,L; void init() { lim = 1,L = 0; while(lim<=2*max(n,m))lim<<=1,L++; for(int i=1;i<lim;i++) to[i] = ((to[i>>1]>>1)|((i&1)<<(L-1))); } ll A[N],B[N],C[N]; void fft(cp*a,int len,int k) { for(int i=0;i<len;i++) if(i<to[i])swap(a[i],a[to[i]]); for(int i=1;i<len;i<<=1) { cp w0(cos(Pi/i),k*sin(Pi/i)); for(int j=0;j<len;j+=(i<<1)) { cp w(1,0); for(int o=0;o<i;o++,w=w*w0) { cp w1 = a[j+o],w2 = a[j+o+i]*w; a[j+o] = w1+w2; a[j+o+i] = w1-w2; } } } if(k==-1) for(int i=0;i<len;i++)a[i].x/=len; } cp a[N],b[N],c[N],d[N],e[N],f[N],g[N],h[N]; void mtt() { int M = 32768; for(int i=0;i<max(n,m);i++) { a[i].x = A[i]/M,b[i].x = A[i]%M; c[i].x = B[i]/M,d[i].x = B[i]%M; } fft(a,lim,1),fft(b,lim,1),fft(c,lim,1),fft(d,lim,1); for(int i=0;i<lim;i++) { e[i] = a[i]*c[i],f[i] = a[i]*d[i]; g[i] = b[i]*c[i],h[i] = b[i]*d[i]; } fft(e,lim,-1),fft(f,lim,-1),fft(g,lim,-1),fft(h,lim,-1); for(int i=0;i<lim;i++) C[i] = (((ll)(e[i].x+0.1))%MOD*M%MOD*M%MOD+((ll)(f[i].x+0.1))%MOD*M%MOD +((ll)(g[i].x+0.1))%MOD*M%MOD+((ll)(h[i].x+0.1))%MOD)%MOD; } int main() { read(n),read(m),read(MOD);n++,m++; init(); for(int i=0;i<n;i++)read(A[i]); for(int i=0;i<m;i++)read(B[i]); mtt(); for(int i=0;i<n+m-1;i++)printf("%lld ",C[i]); puts(""); return 0; }