基本
通俗地说, 系数表达 → 点值表达, 称为 DFT, 点值表达 → 系数表达, 称为 IDFT。
FFT 通过取某些特殊的 x 的点值来加速 DFT 和 IDFT。
考虑点值表示下的多项式乘法:
[f(x) = (x_0,f(x_0)),(x_1,f(x_1)),cdots,(x_n,f(x_n))
\
g(x) = (x_0,g(x_0)),(x_1,g(x_1)),cdots,(x_n,g(x_n))
\
(fcdot g)(x) = f(x)cdot g(x)
\
(fcdot g)(x) = (x_0,f(x_0)g(x_0)), (x_1,f(x_1)g(x_1)), cdots,(x_n,f(x_n)g(x_n))
]
明显是 O(n) 的。
如此,通过 FFT, 可以实现快速的多项式乘法。
分治结构
[egin{align}
f(x) &= a_0 + a_1x+a_2x^2+a_3x^3+a_4x^4+a_5x^5+a_6x^6+a_7x^7
\
&= (a_0+a_2x^2+a_4x^4+a_6x^6) + xcdot(a_1+a_3x^2+a_5x^4+a_7x^6)
end{align}
]
设个 (g(x) = a_0 + a_2x + a_4x^2 + a_6x^3), 再设个 (h(x) = a_1 + a_3x + a_5x^2 + a_7x^3), 就有:
[f(x) = g(x^2) + xcdot h(x^2)
]
接下来是精彩的地方, 前面说的特殊点值要发挥作用了。
带入 n 次的某个单位根, 首先有:
[egin{align}
f(omega_n^k) &= g(omega_n^{2k}) + omega_n^k cdot h(omega_n^{2k})
\
&= g(omega_{n/2}^k) + omega_n^kcdot h(omega_{n/2}^k)
end{align}
]
然后有:
[egin{align}
f(omega_n^{k + n/2}) &= g(omega_{n}^{2k+n}) + omega_{n}^{k+n/2}cdot h(omega_n^{2k+n})
\
&= g(omega_{n/2}^k) - omega_n^kcdot h(omega_{n/2}^k)
end{align}
]
这个分治的结构就清晰可见了, 虽然本质什么的还不是很清楚, 但可以窥见一丝构造的痕迹。
IDFT
带入单位根的共轭复数 DFT 一下, 再把得到的东西除以 n 就行了。
代码
抄的学长的实现, 是有点优化的写法。目前没考虑到封装。
不用算 rev 的 FFT 真的那么 dio 吗?
#include<bits/stdc++.h>
using namespace std;
int rd() {
int x = 0;
char c = getchar();
while(c<'0' || c>'9') c=getchar();
while(c>='0' && c<='9') x=x*10+c-'0', c=getchar();
return x;
}
const int N = (1<<21)+ 233;
const double pi = acos(-1);
struct com {
double x, y;
com(double a, double b) : x(a), y(b) {
}
com() {
x=y=0;
}
const com operator+(const com rhs) const{
return com(x+rhs.x, y+rhs.y);
}
const com operator-(const com rhs) const{
return com(x-rhs.x, y-rhs.y);
}
const com operator*(const com rhs) const{
return com(x*rhs.x - y*rhs.y, x*rhs.y + y*rhs.x);
}
};
int n, m, rv[N];
com a[N], b[N];
void fft(com *a, int n, int type) {
for(int i=0; i<n; ++i) if(i<=rv[i]) swap(a[i], a[rv[i]]);
for(int m=2; m<=n; m<<=1) {
com w(cos(2 * pi / m), type * sin(2 * pi / m));
for(int i=0; i<n; i += m) {
com tmp = com(1, 0);
for(int j=0; j<(m>>1); ++j) {
com p = a[i+j], q = tmp * a[i+j+(m>>1)];
a[i+j] = p + q,
a[i+j+(m>>1)] = p - q;
tmp = tmp * w;
}
}
}
}
int main() {
n = rd()+1, m = rd() + 1;
for(int i=0; i<n; ++i) a[i].x = rd();
for(int i=0; i<m; ++i) b[i].x = rd();
for(m=n+m-1, n=1; n<m; n=n<<1);
for(int i=0; i<n; ++i) rv[i] = (rv[i>>1]>>1)|(i&1?(n>>1):0);
fft(a, n, 1);
fft(b, n, 1);
for(int i=0; i<n; ++i) a[i] = a[i] * b[i];
fft(a, n, -1);
for(int i=0; i<m; ++i) printf("%d ", (int)(a[i].x/n+0.5));
return 0;
}
现在正要加速学习多项式, 所以 NTT 就先背个版吧。
#include<bits/stdc++.h>
typedef long long LL;
using namespace std;
const int N = 3e6 + 23, mo = 998244353, g = 3;
int read() {
char c = getchar(); int x = 0;
while(c < '0' || c > '9') c = getchar();
while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return x;
}
int ksm(int a, int b) {
int res = 1;
for(; b; b=b>>1, a=((LL)a*a) % mo)
if(b & 1) res = (LL)res * a % mo;
return res;
}
const int ig = ksm(g, mo-2);
int n, m, a[N], b[N], rv[N];
void ntt(int *a, int n, int type) {
for(int i=0; i<n; ++i) if(i<rv[i]) swap(a[i], a[rv[i]]);
for(int m=2; m<=n; m<<=1) {
int w = ksm(type == 1 ? g : ig, (mo-1)/m);
for(int i = 0; i < n; i += m) {
int tmp = 1;
for(int j = 0; j < (m>>1); ++j) {
int p = a[i+j], q = (LL)tmp * a[i+j+(m>>1)] % mo;
a[i + j] = (p + q) % mo, a[i + j + (m>>1)] = (p - q + mo) % mo;
tmp = (LL)tmp * w % mo;
}
}
}
}
int main() {
n = read()+1, m = read()+1;
for(int i=0;i<n;++i) a[i]=read();
for(int i=0;i<m;++i) b[i]=read();
for(m=n+m-1,n=1;n<m;n<<=1);
for(int i=0;i<n;++i) rv[i] = (rv[i>>1]>>1)|((i&1)?(n>>1):0);
ntt(a, n, 1), ntt(b, n, 1);
for(int i = 0; i < n; ++i) a[i] = (LL)a[i] * b[i] % mo;
ntt(a, n, -1);
int inv = ksm(n, mo-2);
for(int i=0; i<m; ++i) cout << (LL)a[i] * inv % mo << ' ';
return 0;
}