• FFT 模板


    $DeclareMathOperator{ ev}{rev}$

    FFT(Fast Fourier Transform),确切地说应该称之为 FDFT(Fast Discrete Fourier Transform),因为 FFT 是为了解决 DFT 问题而设计的一种快速算法。在深入讨论之前,有必要特别指出这一点。

    DFT 问题:

    给定复数域上的 $n-1$ 次多项式 $A(x)$ 的系数表示(coefficient representation)$(a_0, a_1,dots, a_{n-1})$,求 $A(x)$ 的某个点-值表示(point-value representation):

    [((x_0, y_0), (x_1, y_1), (x_2, y_2), dots, (x_{n-1}, y_{n-1}))]

    FFT 的非递归(迭代)实现 Version I 手写 Complex 类(《算法导论》)

     1 #include <bits/stdc++.h>
     2 #define rep(i, l, r) for(int i=l; i<r; i++)
     3 using namespace std;
     4 const double PI(acos(-1));
     5 
     6 struct Complex{
     7     double r, i;
     8     Complex(double r, double i):r(r), i(i){}
     9     Complex(int n):r(cos(2*PI/n)), i(sin(2*PI/n)){}    //!!error-prone
    10     Complex():r(0), i(0){}    //default constructor
    11     Complex &operator*=(const Complex &a){
    12         double R=r*a.r-i*a.i, I=r*a.i+a.r*i;
    13         r=R, i=I;
    14         return *this;
    15     }
    16     Complex operator+(const Complex a){
    17         return Complex(r+a.r, i+a.i);
    18     }
    19     Complex operator-(const Complex a){
    20         return Complex(r-a.r, i-a.i);
    21     }
    22     Complex operator*(const Complex a){
    23         return Complex(r*a.r-i*a.i, r*a.i+a.r*i);
    24     }
    25     void out(){
    26         cout<<r<<' '<<i<<endl;
    27     }
    28 };
    29 
    30 const int N(1<<17);
    31 int ans[N];
    32 Complex a[N], b[N];
    33 char s[N], t[N];
    34 
    35 void bit_reverse_swap(Complex *a, int n){
    36     for(int i=1, j=n>>1, k; i<n-1; i++){
    37         if(i < j) swap(a[i],a[j]);
    38         //tricky
    39         for(k=n>>1; j>=k; j-=k, k>>=1);    //inspect the highest "1"
    40         j+=k;
    41     }
    42 }
    43 
    44 void FFT(Complex* a, int n, int t){
    45     bit_reverse_swap(a, n);
    46     for(int i=2; i<=n; i<<=1){
    47         Complex wi(i*t);
    48         for(int j=0; j<n; j+=i){
    49             Complex w(1, 0);
    50             for(int k=j, h=i>>1; k<j+h; k++){
    51                 Complex t=w*a[k+h], u=a[k];
    52                 a[k]=u+t;
    53                 a[k+h]=u-t;
    54                 w*=wi;
    55             }
    56         }
    57     }
    58     if(t==-1) rep(i, 0, n) a[i].r/=n;    //!!error-prone
    59 }
    60 
    61 int trans(int x){
    62     int i=0;
    63     for(; x>1<<i; i++);
    64     return 1<<i;
    65 }
    66 
    67 int main(){
    68     for(; ~scanf("%s%s", s, t); ){
    69         int n=strlen(s), m=strlen(t), l=trans(n+m-1);
    70         rep(i, 0, n) a[i]=Complex(s[n-1-i]-'0', 0);
    71         rep(i, n, l) a[i]=Complex(0, 0);
    72         rep(i, 0, m) b[i]=Complex(t[m-1-i]-'0', 0);
    73         rep(i, m, l) b[i]=Complex(0, 0);
    74 
    75         FFT(a, l, 1), FFT(b, l, 1);
    76         rep(i, 0, l) a[i]*=b[i];
    77         FFT(a, l, -1);
    78         rep(i, 0, l) ans[i]=(int)(a[i].r+0.5); ans[l]=0;    //error-prone
    79         rep(i, 0, l) ans[i+1]+=ans[i]/10, ans[i]%=10;
    80         int c=l;
    81         for(;c && !ans[c]; --c);
    82         for(; ~c; putchar(ans[c--]+'0'));    //error-prone
    83         puts("");
    84     }
    85     return 0;
    86 }

    Comment:

    1. 此代码是为 HDU1402 写的。代码中,凡注释 error-prone 处,都应特别小心。我犯的最傻逼的错误是第9行,应当是2*PI,我写成PI了。

    2. FFT的数值稳定性(精度)问题,还有待考虑。(UPD)多次做多项式乘法时,精度损失较快,这时将 double 换成 long double 可缓解精度损失。


    Comment:

    1. bit_reverse_swap()函数是对算法导论上的bit_reverse_copy()的改进,将下标互为bit-reverse 的两元素互换位置,就免去了 copy 所需的空间。

    2.bit_revrese_copy()不太好懂,需要一点解释:

    1 void bit_reverse_swap(Complex *a, int n){
    2     for(int i=1, j=n>>1, k; i<n-1; i++){
    3         if(i < j) swap(a[i],a[j]);
    4         //tricky
    5         for(k=n>>1; j>=k; j-=k, k>>=1);    //inspect the highest "1"
    6         j+=k;
    7     }
    8 }

    将$i$的bit-reverse记作$ ev(i)$。

    (i). 由于 $ ev(0)=1, ev(n-1)=n-1$($n$ 是 $2$ 的幂),所以第 2 行的主循环可令 $i$ 从 $1$ 循环到 $n-2$。同时 $j$ 从 $ ev(1)= frac{n}{2}$,“循环”到 $ ev(n-2)$ 。

    (ii). 第 3 行的判断 if(i < j) 避免了重复交换

    (iii).第 5 行的循环的作用就是将 $j$ 从 $ ev(i)$ 变成 $ ev(i+1)$:

    首先应当注意到,$i$ 的最低位恰是 $rev(i)$ 的最高位。若 $ ev(i)$ 的最高位是 $0$ 那么 $ ev(i+1)$ 就是 $ ev(i)+ frac{n}{2}$,否则,$i$ 加上 $1$ 后,最低位将变成 $0$,并且向高一位进 $1$ 。相应的,$ ev(i+1)$ 的最高位应置 $0$(即代码中的 j-=k),并且向低一位"进“ $1$(对应代码中的 k>>=1)。这样从高位往低位检查,遇到 $1$(对应代码中的条件 j>=k)就进位,遇到 $0$ 就退出循环。
    3. 我写代码时把第 58 行的 == 写成了 =,结果DEBUG 一个多小时。。。

    Version II: 用 C++ 标准库中的 complex<double> 类,代码短一些,但也会慢一些:

     1 #include <bits/stdc++.h>
     2 #define rep(i, l, r) for(int i=l; i<r; i++)
     3 using namespace std;
     4 const double PI(acos(-1));
     5 typedef complex<double> C;
     6 
     7 const int N(1<<17);
     8 int ans[N];
     9 C a[N], b[N];
    10 char s[N], t[N];
    11 
    12 void bit_reverse_swap(C *a, int n){
    13     for(int i=1, j=n>>1, k; i<n-1; i++){
    14         if(i < j) swap(a[i],a[j]);
    15         //tricky
    16         for(k=n>>1; j>=k; j-=k, k>>=1);    //inspect the highest "1"
    17         j+=k;
    18     }
    19 }
    20 
    21 void FFT(C* a, int n, int t){
    22     bit_reverse_swap(a, n);
    23     for(int i=2; i<=n; i<<=1){
    24         C wi(cos(t*2*PI/i), sin(t*2*PI/i));
    25         for(int j=0; j<n; j+=i){
    26             C w(1);
    27             for(int k=j, h=i>>1; k<j+h; k++){
    28                 C t=w*a[k+h], u=a[k];
    29                 a[k]=u+t;
    30                 a[k+h]=u-t;
    31                 w*=wi;
    32             }
    33         }
    34     }
    35     if(t==-1) rep(i, 0, n) a[i]/=n;    //!!error-prone: typo ==/=
    36 }
    37 
    38 int trans(int x){
    39     int i=0;
    40     for(; x>1<<i; i++);
    41     return 1<<i;
    42 }
    43 
    44 int main(){
    45     for(; ~scanf("%s%s", s, t); ){
    46         int n=strlen(s), m=strlen(t), l=trans(n+m-1);
    47         rep(i, 0, n) a[i]=C(s[n-1-i]-'0');
    48         rep(i, n, l) a[i]=C(0);
    49         rep(i, 0, m) b[i]=C(t[m-1-i]-'0');
    50         rep(i, m, l) b[i]=C(0);
    51 
    52         FFT(a, l, 1), FFT(b, l, 1);
    53         rep(i, 0, l) a[i]*=b[i];
    54         FFT(a, l, -1);
    55         rep(i, 0, l) ans[i]=(int)(a[i].real()+0.5); ans[l]=0;    //error-prone
    56         rep(i, 0, l) ans[i+1]+=ans[i]/10, ans[i]%=10;
    57         int p=l;
    58         for(;p && !ans[p]; --p);
    59         for(; ~p; putchar(ans[p--]+'0'));    //error-prone
    60         puts("");
    61     }
    62     return 0;
    63 }

     

  • 相关阅读:
    为什么重写equals方法还要重写hashcode方法?
    提高数据库处理查询速度
    ibatis缓存初探(1)
    java四种数组排序
    前台将勾选的多个属性放到一个value里面,是一个字符串,传到后台
    Apache与Tomcat整合
    web服务器和应用服务器概念比较
    ibaits与spring整合的心得
    spring3.0的jar包详解
    JAVA:23种设计模式详解(转)2
  • 原文地址:https://www.cnblogs.com/Patt/p/5503322.html
Copyright © 2020-2023  润新知