用蝴蝶操作（位逆序置换 + 自底向上迭代合并）优化的快速傅里叶变换（FFT），用于多项式乘法。
// Polynomial multiplication with an iterative FFT (bit-reversal permutation + butterflies).
// Input : N M, then the N+1 coefficients of A, then the M+1 coefficients of B (lowest degree first).
// Output: the N+M+1 coefficients of A*B, each followed by a space.
#include <cmath>
#include <cstdio>
#include <utility>
#include <vector>

// Reads one (possibly negative) decimal integer from stdin, skipping any non-digit prefix.
// Returns 0 if input ends before a digit is found (instead of looping forever on EOF).
inline int read()
{
    int c = getchar();
    int x = 0, f = 1;
    while ((c < '0' || c > '9') && c != EOF) {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + c - '0';
        c = getchar();
    }
    return x * f;
}

const double Pi = std::acos(-1.0);

// Minimal complex number. Named Complex so it cannot collide with std::complex.
struct Complex
{
    double x, y;  // real and imaginary parts
    Complex(double xx = 0, double yy = 0) : x(xx), y(yy) {}
};
inline Complex operator+(Complex a, Complex b) { return Complex(a.x + b.x, a.y + b.y); }
inline Complex operator-(Complex a, Complex b) { return Complex(a.x - b.x, a.y - b.y); }
// (a.x + i*a.y)(b.x + i*b.y) = (a.x*b.x - a.y*b.y) + i*(a.x*b.y + a.y*b.x)
inline Complex operator*(Complex a, Complex b) { return Complex(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x); }

int limit = 1;       // transform length: smallest power of two strictly greater than N + M
int l = 0;           // log2(limit)
std::vector<int> r;  // r[i] = i with its lowest l bits reversed

// In-place iterative FFT over A[0, limit).
// type = 1 for the forward transform, type = -1 for the inverse (caller divides by limit afterwards).
void fast_fast_tle(Complex *A, int type)
{
    // Put elements in bit-reversed order so the butterflies below can run in place.
    for (int i = 0; i < limit; i++)
        if (i < r[i]) std::swap(A[i], A[r[i]]);
    // mid = half the length of the blocks being merged at this level.
    for (int mid = 1; mid < limit; mid <<= 1) {
        const Complex Wn(std::cos(Pi / mid), type * std::sin(Pi / mid));  // primitive (2*mid)-th root of unity
        // j = start of the current block of length R = 2*mid.
        for (int R = mid << 1, j = 0; j < limit; j += R) {
            Complex w(1, 0);  // w = Wn^k
            for (int k = 0; k < mid; k++, w = w * Wn) {
                const Complex x = A[j + k], y = w * A[j + mid + k];  // butterfly
                A[j + k] = x + y;
                A[j + mid + k] = x - y;
            }
        }
    }
}

int main()
{
    const int N = read(), M = read();
    while (limit <= N + M) limit <<= 1, l++;

    std::vector<Complex> a(limit), b(limit);
    for (int i = 0; i <= N; i++) a[i].x = read();
    for (int i = 0; i <= M; i++) b[i].x = read();

    // r[i] is r[i/2] shifted right by one, with i's lowest bit moved to position l-1.
    // Start at i = 1: r[0] = 0, and when l == 0 the shift by (l - 1) would be negative (UB).
    r.assign(limit, 0);
    for (int i = 1; i < limit; i++)
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));

    fast_fast_tle(a.data(), 1);
    fast_fast_tle(b.data(), 1);
    for (int i = 0; i < limit; i++) a[i] = a[i] * b[i];  // pointwise product in the frequency domain
    fast_fast_tle(a.data(), -1);

    // llround rounds to nearest for negative coefficients too (x + 0.5 truncation does not).
    for (int i = 0; i <= N + M; i++)
        printf("%lld ", std::llround(a[i].x / limit));
    return 0;
}
数论变换（NTT）：在模 998244353 意义下做与 FFT 相同的蝴蝶运算，全程使用整数运算，避免了 FFT 的浮点精度误差，一般也比 FFT 更快。
// Polynomial multiplication with the Number Theoretic Transform (NTT) modulo P = 998244353.
// Input : N M, then the N+1 coefficients of A, then the M+1 coefficients of B (lowest degree first).
// Output: the N+M+1 coefficients of A*B reduced modulo P, each followed by a space.
#include <cstdio>
#include <utility>
#include <vector>

using LL = long long;
const int P = 998244353, G = 3, Gi = 332748118;  // G is a primitive root mod P; Gi * G == 1 (mod P)

static char buf[1 << 21], *p1 = buf, *p2 = buf;

// Buffered character reader: refills buf from stdin in 2 MiB chunks; returns EOF when input is exhausted.
inline int gc()
{
    if (p1 == p2) {
        p2 = (p1 = buf) + fread(buf, 1, sizeof buf, stdin);
        if (p1 == p2) return EOF;
    }
    return static_cast<unsigned char>(*p1++);
}

// Reads one (possibly negative) decimal integer, skipping any non-digit prefix.
// Returns 0 if input ends before a digit is found (instead of looping forever on EOF).
inline int read()
{
    int c = gc(), x = 0, f = 1;
    while ((c < '0' || c > '9') && c != EOF) {
        if (c == '-') f = -1;
        c = gc();
    }
    while (c >= '0' && c <= '9') x = x * 10 + c - '0', c = gc();
    return x * f;
}

int N, M, limit = 1, L;  // limit = smallest power of two > N + M, L = log2(limit)
std::vector<int> r;      // r[i] = i with its lowest L bits reversed

// Returns a^k mod P by binary exponentiation (k >= 0).
inline LL fastpow(LL a, LL k)
{
    LL base = 1;
    a %= P;
    while (k) {
        if (k & 1) base = base * a % P;
        a = a * a % P;
        k >>= 1;
    }
    return base % P;
}

// In-place iterative NTT over A[0, limit); entries must already be in [0, P).
// type = 1 for the forward transform, otherwise the inverse (caller multiplies by limit^{-1}).
inline void NTT(LL *A, int type)
{
    // Bit-reversal permutation so the butterflies below can run in place.
    for (int i = 0; i < limit; i++)
        if (i < r[i]) std::swap(A[i], A[r[i]]);
    // mid = half the length of the blocks being merged at this level.
    for (int mid = 1; mid < limit; mid <<= 1) {
        // Primitive (2*mid)-th root of unity mod P (or its inverse for the inverse transform).
        const LL Wn = fastpow(type == 1 ? G : Gi, (P - 1) / (mid << 1));
        for (int j = 0; j < limit; j += (mid << 1)) {
            LL w = 1;  // w = Wn^k
            for (int k = 0; k < mid; k++, w = w * Wn % P) {
                const LL x = A[j + k], y = w * A[j + k + mid] % P;  // butterfly
                A[j + k] = (x + y) % P;
                A[j + k + mid] = (x - y + P) % P;
            }
        }
    }
}

int main()
{
    N = read();
    M = read();
    while (limit <= N + M) limit <<= 1, L++;

    std::vector<LL> a(limit, 0), b(limit, 0);
    // Reduce into [0, P) in 64-bit arithmetic so large or negative inputs cannot overflow int.
    for (int i = 0; i <= N; i++) a[i] = (static_cast<LL>(read()) % P + P) % P;
    for (int i = 0; i <= M; i++) b[i] = (static_cast<LL>(read()) % P + P) % P;

    // r[i] is r[i/2] shifted right by one, with i's lowest bit moved to position L-1.
    // Start at i = 1: r[0] = 0, and when L == 0 the shift by (L - 1) would be negative (UB).
    r.assign(limit, 0);
    for (int i = 1; i < limit; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));

    NTT(a.data(), 1);
    NTT(b.data(), 1);
    for (int i = 0; i < limit; i++) a[i] = a[i] * b[i] % P;  // pointwise product
    NTT(a.data(), -1);

    const LL inv = fastpow(limit, P - 2);  // limit^{-1} mod P via Fermat's little theorem
    for (int i = 0; i <= N + M; i++)
        printf("%lld ", a[i] * inv % P);  // %lld: the argument is long long
    return 0;
}