问题:
已知$A=a_{0..n-1}$, $B=b_{0..n-1}$, 求$C=c_{0..2n-2}$,使:
$$c_i = sum_{j=0}^ia_jb_{i-j}$$
定义$C$是$A$,$B$的卷积,记作
$$C = A * B$$
例如多项式乘法等。
朴素做法是按照定义枚举$i$和$j$,但这样时间复杂度是$O(n^2)$.
能不能使时间复杂度降下来呢?
点值表示法:
我们把$A$,$B$,$C$看作多项式。
即:
$$A(x) = sum_{i=0}^{n-1}a_ix^i$$
将$A=left{(x_1,A(x_1)), (x_2,A(x_2)), (x_3,A(x_3))... ight}$叫做$A$的点值表示法。
由于$A$是$n-1$次多项式,我们恰好需要$n$个点值来确定它。
那么使用点值表示法做多项式乘法就很简单了:对应项相乘(这样的话要将$n$扩大一倍,因为$C$的次数约为$2n$)。
那么,如何将$A$和$B$转换成点值表示法,再将$C$转化回系数表示法(即最初的表示方法)呢?
如果任取$n$个点,按照定义计算,那么还是$O(n^2)$的。
这样就要用到快速傅里叶变换。
快速傅里叶变换:
既然任取$n$个点,按照定义计算太慢,就要找一些特殊点。
我们用$n$个$n$次单位复数根($1$的$n$次方根,涉及到复数,$1$的方根不止$1$和$-1$)来计算:
根据欧拉公式,$e^{i heta}=cos heta+isin heta$(其中$i$是虚数单位),那么$e^{2pi i}=cos(2pi)+isin(2pi)=1$.
所以1的n次方根是$omega_n^k=e^{frac{2kpi i}n}qquad0leq k<n$。
其中$omega_n^1 = e^{frac{2pi i}n}$是主$n$次单位根,那么所有$n$次单位复数根都是它的幂。
我们要求出$A(omega_n^k)$,就要采用分治思想。
我们将奇偶系数分离(先假设n为偶数),即定义
$$A_0(x)=a_0 + a_2* x + a_4 * x^2 +cdots=sum_{i=0}^{frac{n}2-1}a_{2i}x^i$$
$$A_1(x)=a_1 + a_3* x + a_5 * x^2 +cdots=sum_{i=0}^{frac{n}2-1}a_{2i+1}x^i$$
那么$A(x)=A_0(x^2) + xA_1(x^2)$。
要计算$A(omega_n^k)=A_0left[(omega_n^k)^2 ight] + omega_n^kA_1left[(omega_n^k)^2 ight]$,
就要用到$(omega_n^k)^2 = omega_{n/2}^{k\,mod (n/2)}$(证略)。
所以$A(omega_n^k)=A_0left(omega_{n/2}^{k\,mod (n/2)} ight) + omega_n^kA_1left(omega_{n/2}^{k\,mod (n/2)} ight)$
我们发现$A_0$,$A_1$都是$n/2$项的,且只需要算$omega_{n/2}^k$的值,那么这就和开始的问题一样了,可以分治。
边界也很容易:$n=1$的时候$A_0$本身就是值。
合并解。
$$A(omega_n^k)=A_0left(omega_{n/2}^{k\,mod (n/2)} ight) + omega_n^kA_1left(omega_{n/2}^{k\,mod (n/2)} ight)$$
那么可以$A(omega_n^k), A(omega_n^{k+n/2})$一起算$(0leq k<n/2)$ :
令$u = A_0(omega_{n/2}^k), t = omega_n^kA_1(omega_{n/2}^k)$,
那么
$$A(omega_n^k)=u + t$$
$$egin{aligned}& quad A(omega_n^{k+n/2}) \
&= A_0(omega_{n/2}^k) + omega_n^{k+n/2}A_1(omega_{n/2}^k)\
&= A_0(omega_{n/2}^k) + omega_n^komega_n^{n/2}A_1(omega_{n/2}^k)\
&= A_0(omega_{n/2}^k) - omega_n^kA_1(omega_{n/2}^k)\
&= u-tend{aligned}$$
所以这样就能算出$A$的点值表示法。
一个问题:分治要求$n$是$2$的幂,不是怎么办? 补$0$, 直到$n$是$2$的幂。
时间复杂度:
$$T(n)=2T(n/2)+O(n)$$
直接观察或者应用主定理都可得出$T(n)=O(nlogn)$
剩下的问题:如何把C转化回系数表示法。
逆变换:
我们把C做一遍快速傅立叶变换,只是求的是$omega_n^n, omega_n^{n-1}, cdot,omega_n^1$的值而不是$omega_n^0, omega_n^1, cdot,omega_n^{n-1}$的值,最后每一项除以n即可。
证明(实际上可以利用逆矩阵,但我就写的麻烦一点吧qwq):
这样我实际上是求了$c_i=frac1nsum_{k=0}^{n-1}(omega_n^{-i})^kC(omega_n^k)$。我们来直接证明这是正确的。
$$egin{aligned} frac1nsum_{k=0}^{n-1}(omega_n^{-i})^kC(omega_n^k)&=frac1nsum_{k=0}^{n-1}omega_n^{-ki}sum_{j=0}^{n-1}c_j(omega_n^k)^j\ &=frac1nsum_{j=0}^{n-1}c_jsum_{k=0}^{n-1}omega_n^{jk}omega_n^{-ki}\ &=frac1nsum_{j=0}^{n-1}c_jsum_{k=0}^{n-1}(omega_n^{j-i})^k end{aligned}$$
$sum_{k=0}^{n-1}(omega_n^{j-i})^k$是一个公比为$omega_n^{j-i}$的等差数列。在$i=j$时$n$项都为$1$,显然其值为$n$;$j eq i$的时候根据等差数列求和公式他就等于
$$frac{1-(omega_n^{j-i})^n}{1-omega_n^{j-i}}$$
而$(omega_n^{j-i})^n$等于$1$,所以上式就是$0$。于是 $$frac1nsum_{k=0}^{n-1}(omega_n^{-i})^kC(omega_n^k)=frac1nsum_{j=0}^{n-1}c_jsum_{k=0}^{n-1}(omega_n^{j-i})^k=frac1nsum_{j=0}^{n-1}c_j imes n[i == j]=c_i$$ 证毕。
1 #include <algorithm> 2 #include <cmath> 3 const double pi = acos(-1.0); 4 struct complex{ 5 double real, impl; 6 complex(double r = 0.0, double i = 0.0) : real(r), impl(i) {} 7 friend complex operator+(const complex &a, const complex &b) { 8 return complex(a.real + b.real, a.impl + b.impl); 9 } 10 friend complex operator-(const complex &a, const complex &b) { 11 return complex(a.real - b.real, a.impl - b.impl); 12 } 13 friend complex operator*(const complex &a, const complex &b) { 14 return complex(a.real * b.real - a.impl * b.impl, a.impl * b.real + b.impl * a.real); 15 } 16 friend complex operator/(const complex &a, double b) { 17 return complex(a.real / b, a.impl / b); 18 } 19 }; 20 using std::swap; 21 void FFT(complex* P, int len, int opt) { 22 for (int i = 1, j = 0, k; i < len; ++i) { 23 for (k = len >> 1; j & k; k >>= 1) j ^= k; 24 j ^= k; 25 if (i < j) swap(P[i], P[j]); 26 } 27 for (int h = 2; h <= len; h <<= 1) { 28 complex wn = complex(cos(opt * 2 * pi / h), sin(opt * 2 * pi / h)); 29 for (int j = 0; j < len; j += h) { 30 complex w = complex(1.0, .0); 31 for (int t = 0; t < h / 2; ++t, w = w * wn) { 32 complex tmp1 = P[t + j], tmp2 = P[t + j + h / 2]; 33 P[t + j] = tmp1 + tmp2 * w; 34 P[t + j + h / 2] = tmp1 - tmp2 * w; 35 } 36 } 37 } 38 if (opt == -1) 39 for (int i = 0; i < len; ++i) 40 P[i] = P[i] / len; 41 }