• 《数据结构与算法分析:C语言描述》复习——第十章“算法设计技巧”——Strassen矩阵乘法


    2014.07.06 18:52

    简介:

      给定两个大小相同的方阵A和B,我们要计算AXB。方阵的大小是2的整次幂,比如2^k。对于这种特殊大小的方阵乘法,Strassen算法能够带来一定程度的加速,矩阵越大,加速效果越明显。

    描述:

      这个例子其实也是分治法的典型算法,通过矩阵分块进行乘法,然后合并结果。

      不过我觉得这个例子,更突出的特点其实是“数学”二字。我想作者应该是想告诉我们,数学和计算机的关联,就在于用数学工具来思考问题可以把某些你想不清楚的东西一步一步推导清楚。不论是代数、几何、微积分(数学系的如果路过请勿鄙视,我们作为码农学好这些已经算是恪守本分了)、数值分析、科学计算等等,都不可能让你觉得很实用,但只要某一个瞬间能从这些交过学费的课程里得到些灵感,那就不算白学了。

      根据矩阵乘法的基本定义,两个nXn的方阵相乘需要O(n^3)级别的算术运算。而Strassen算法的时间复杂度大概O(n^2.81),这样看起来不很显著的数量级差异需要较大的矩阵才能显示出性能的差异。因此Strassen算法的应用并不十分广泛,但思想的确是很有用的(至少Volker Strassen教授告诉我们学好代数是非常重要的)。初学这个算法时的感受,就和学习斐波那契数的对数级算法一样——数学要好好学,因为你学的那点数学实在很浅。

      矩阵的分块我们在高中的代数里就学过,如果一个方阵的尺寸恰好是2的整次幂,那么我们可以一直进行四等分,直到尺寸为1。

      于是矩阵A可以等分为四块儿:

        ┏━━┳━━┓

        ┃A1,1┃A1,2┃

        ┣━━╋━━┫

        ┃A2,1┃A2,2┃

        ┗━━┻━━┛

      矩阵B也一样:

        ┏━━┳━━┓

        ┃B1,1┃B1,2┃

        ┣━━╋━━┫

        ┃B2,1┃B2,2┃

        ┗━━┻━━┛

       本来2X2X2=8,如果对于这些子矩阵需要进行8次乘法,那就没有效率上的提升。但Herr Strassen设计了一种神奇的计算方法,使用7次子矩阵乘法,请看下面的公式:

        1. M1=(A12-A22)(B21+B22)

        2. M2=(A11+A22)(B11+B22)

        3. M3=(A11-A21)(B11+B12)

        4. M4=(A11+A12)B22

        5. M5=A11(B12-B22)

        6. M6=A22(B21-B11)

        7. M7=(A21+A22)B11

      这7个小矩阵M1~M7都是n/2大小的方阵。通过下面四个公式又可以将它们组合成最终结果C:

        1. C11=M1+M2-M4+M6

        2. C12=M4+M5

        3. C21=M6+M7

        4. C22=M2-M3+M5-M7

        ┏━━┳━━┓

        ┃C1,1┃C1,2┃

        ┣━━╋━━┫

        ┃C2,1┃C2,2┃

        ┗━━┻━━┛

      七个式子是乘法,属于子问题的分解;四个式子是加减法,复杂度是严格O(n^2)的。两者的复杂度在下面会进行推导。

      接下来开启高中模式,用数列知识来证明O(n^2.81)的复杂度(高考之前的我们应该是对数列问题最敏感的):

        T(n)=7T(n/2)+O(n^2),前半部分表示7次子矩阵乘法,后半部分是4次加减法。

        接下来把7T(n/2)这部分逐层拆成49T(n/4)、343T(n/8)...

        T(n)=7(7T(n/4)+O(n^2/4))+O(n^2)

        T(n)=49T(n/4)+O(n^2/4)+O(n^2)

        ...

        T(n)=7^k * T(n/(2^k))+O((n/(2^k))^2)+...+O((n/2)^2)+O(n^2)

        此处的k等于多少呢?等于log2n。为什么呢?因为2^k=n,这样拆了k次以后就不能再拆了,而k不一定要是整数。

        再来看看后面的等比数列部分O(n^2)+O(n^2/4)+...+O(n^2/(4^k))。这部分的和S=O(n^2) * (4/3-1/(3 * 4^k))

        由于2^k=n,于是4^k=n^2。那么S=O(n^2) * (4/3-1/(3 * n^2))=O(n^2)-O(1)=O(n^2)。

        所以后面的等比数列总体的时间复杂度是O(n^2)级别的。

        再来看看前面的部分7^k * T(n/(2^k))

        T(n/(2^k))=O(1),表示这个公式展开了k次以后就不能再展开了。

        7^k=7^(log2n)=7^(log7n/log72)=7^(log7n * log27)=(7^log7n)^log27=n^log27

        又因为log27≈2.81,于是O(n^2.81)+O(n^2)=O(n^2.81)。

      由此看来Strassen算法的O(n^2.81)的时间复杂度是由于化8次子问题为7次才得到的。

      不过Srassen算法的局限性也是不少的:矩阵的尺寸限制的太严格,而且处理子问题的并行能力不太强。因此这个算法更多是出现在教材和ACM题库里,实际应用则没那么多。

    实现:

      1 // My implementation for Strassen Algorithm. Matrix size is limited to powers of 2 only.
      2 #include <iostream>
      3 #include <vector>
      4 using namespace std;
      5 
      6 void getSubmatrix(const vector<vector<int> > &m, vector<vector<int> > &sub, int part)
      7 {
      8     int n = (int)m.size();
      9     int i, j;
     10     int n1 = n / 2;
     11     
     12     int top = part / 2 * n1;
     13     int left = part % 2 * n1;
     14     for (i = 0; i < n1; ++i) {
     15         for (j = 0; j < n1; ++j) {
     16             sub[i][j] = m[top + i][left + j];
     17         }
     18     }
     19 }
     20 
     21 void addMatrix(const vector<vector<int> > &a, const vector<vector<int> > &b, 
     22     vector<vector<int> > &c, int n)
     23 {
     24     int i, j;
     25     
     26     for (i = 0; i < n; ++i) {
     27         for (j = 0; j < n; ++j) {
     28             c[i][j] = a[i][j] + b[i][j];
     29         }
     30     }
     31 }
     32 
     33 void subtractMatrix(const vector<vector<int> > &a, const vector<vector<int> > &b, 
     34     vector<vector<int> > &c, int n)
     35 {
     36     int i, j;
     37     
     38     for (i = 0; i < n; ++i) {
     39         for (j = 0; j < n; ++j) {
     40             c[i][j] = a[i][j] - b[i][j];
     41         }
     42     }
     43 }
     44 
     45 void setSubmatrix(const vector<vector<int> > &sub, vector<vector<int> > &m, int part)
     46 {
     47     int n = (int)m.size();
     48     int i, j;
     49     int n1 = n / 2;
     50     
     51     int top = part / 2 * n1;
     52     int left = part % 2 * n1;
     53     for (i = 0; i < n1; ++i) {
     54         for (j = 0; j < n1; ++j) {
     55             m[top + i][left + j] = sub[i][j];
     56         }
     57     }
     58 }
     59 
     60 void matrixMultiplicationRecursive(const vector<vector<int> > &a, 
     61     const vector<vector<int> > &b, vector<vector<int> > &c, int n)
     62 {
     63     if (n == 1) {
     64         c[0][0] = a[0][0] * b[0][0];
     65         return;
     66     }
     67     
     68     int i;
     69     int n1 = n / 2;
     70     
     71     vector<vector<int> > aa[4];
     72     vector<vector<int> > bb[4];
     73     
     74     for (i = 0; i < 4; ++i) {
     75         aa[i].resize(n1, vector<int>(n1));
     76         bb[i].resize(n1, vector<int>(n1));
     77     }
     78     
     79     for (i = 0; i < 4; ++i) {
     80         getSubmatrix(a, aa[i], i);
     81         getSubmatrix(b, bb[i], i);
     82     }
     83     
     84     vector<vector<int> > x, y;
     85     vector<vector<int> > m[7];
     86     
     87     x.resize(n1, vector<int>(n1));
     88     y.resize(n1, vector<int>(n1));
     89     for (i = 0; i < 7; ++i) {
     90         m[i].resize(n1, vector<int>(n1));
     91     }
     92     
     93     subtractMatrix(aa[1], aa[3], x, n1);
     94     addMatrix(bb[2], bb[3], y, n1);
     95     matrixMultiplicationRecursive(x, y, m[0], n1);
     96     
     97     addMatrix(aa[0], aa[3], x, n1);
     98     addMatrix(bb[0], bb[3], y, n1);
     99     matrixMultiplicationRecursive(x, y, m[1], n1);
    100     
    101     subtractMatrix(aa[0], aa[2], x, n1);
    102     addMatrix(bb[0], bb[1], y, n1);
    103     matrixMultiplicationRecursive(x, y, m[2], n1);
    104     
    105     addMatrix(aa[0], aa[1], x, n1);
    106     matrixMultiplicationRecursive(x, bb[3], m[3], n1);
    107     
    108     subtractMatrix(bb[1], bb[3], y, n1);
    109     matrixMultiplicationRecursive(aa[0], y, m[4], n1);
    110     
    111     subtractMatrix(bb[2], bb[0], y, n1);
    112     matrixMultiplicationRecursive(aa[3], y, m[5], n1);
    113     
    114     addMatrix(aa[2], aa[3], x, n1);
    115     matrixMultiplicationRecursive(x, bb[0], m[6], n1);
    116     
    117     addMatrix(m[0], m[1], x, n1);
    118     subtractMatrix(x, m[3], x, n1);
    119     addMatrix(x, m[5], x, n1);
    120     setSubmatrix(x, c, 0);
    121     
    122     addMatrix(m[3], m[4], x, n1);
    123     setSubmatrix(x, c, 1);
    124     
    125     addMatrix(m[5], m[6], x, n1);
    126     setSubmatrix(x, c, 2);
    127     
    128     subtractMatrix(m[1], m[2], x, n1);
    129     addMatrix(x, m[4], x, n1);
    130     subtractMatrix(x, m[6], x, n1);
    131     setSubmatrix(x, c, 3);
    132 
    133     for (i = 0; i < 4; ++i) {
    134         aa[i].clear();
    135         bb[i].clear();
    136     }
    137     for (i = 0; i < 7; ++i) {
    138         m[i].clear();
    139     }
    140     x.clear();
    141     y.clear();
    142 }
    143 
    144 void matrixMultiplication(const vector<vector<int> > &a, 
    145     const vector<vector<int> > &b, vector<vector<int> > &c)
    146 {
    147     int n = (int)a.size();
    148     
    149     matrixMultiplicationRecursive(a, b, c, n);
    150 }
    151 
    152 int main()
    153 {
    154     int n;
    155     int i, j;
    156     vector<vector<int> > a, b, c;
    157     
    158     while (cin >> n && n > 0) {
    159         a.resize(n, vector<int>(n));
    160         b.resize(n, vector<int>(n));
    161         c.resize(n, vector<int>(n));
    162         
    163         for (i = 0; i < n; ++i) {
    164             for (j = 0; j < n; ++j) {
    165                 cin >> a[i][j];
    166             }
    167         }
    168 
    169         for (i = 0; i < n; ++i) {
    170             for (j = 0; j < n; ++j) {
    171                 cin >> b[i][j];
    172             }
    173         }
    174         
    175         matrixMultiplication(a, b, c);
    176 
    177         for (i = 0; i < n; ++i) {
    178             for (j = 0; j < n; ++j) {
    179                 cout << c[i][j] << ' ';
    180             }
    181             cout << endl;
    182         }
    183         cout << endl;
    184 
    185         a.clear();
    186         b.clear();
    187         c.clear();
    188     }
    189     
    190     return 0;
    191 }
  • 相关阅读:
    js 跳转链接
    reg.test is not a function 报错
    html中button自动提交表单?
    mysql主从复制及双主复制
    nginx反向代理后端web服务器记录客户端ip地址
    mysql多实例-主从复制安装
    LVS+Keepalived高可用负载均衡集群架构实验-01
    debug调试
    常用网站总结
    项目部署
  • 原文地址:https://www.cnblogs.com/zhuli19901106/p/3828519.html
Copyright © 2020-2023  润新知