2014.07.06 18:52
简介:
给定两个大小相同的方阵A和B,我们要计算AXB。方阵的大小是2的整次幂,比如2^k。对于这种特殊大小的方阵乘法,Strassen算法能够带来一定程度的加速,矩阵越大,加速效果越明显。
描述:
这个例子其实也是分治法的典型算法,通过矩阵分块进行乘法,然后合并结果。
不过我觉得这个例子,更突出的特点其实是“数学”二字。我想作者应该是想告诉我们,数学和计算机的关联,就在于用数学工具来思考问题可以把某些你想不清楚的东西一步一步推导清楚。不论是代数、几何、微积分(数学系的如果路过请勿鄙视,我们作为码农学好这些已经算是恪守本分了)、数值分析、科学计算等等,都不可能让你觉得很实用,但只要某一个瞬间能从这些交过学费的课程里得到些灵感,那就不算白学了。
根据矩阵乘法的基本定义,两个nXn的方阵相乘需要O(n^3)级别的算术运算。而Strassen算法的时间复杂度大概O(n^2.81),这样看起来不很显著的数量级差异需要较大的矩阵才能显示出性能的差异。因此Strassen算法的应用并不十分广泛,但思想的确是很有用的(至少Volker Strassen教授告诉我们学好代数是非常重要的)。初学这个算法时的感受,就和学习斐波那契数的对数级算法一样——数学要好好学,因为你学的那点数学实在很浅。
矩阵的分块我们在高中的代数里就学过,如果一个方阵的尺寸恰好是2的整次幂,那么我们可以一直进行四等分,直到尺寸为1。
于是矩阵A可以等分为四块儿:
┏━━┳━━┓
┃A1,1┃A1,2┃
┣━━╋━━┫
┃A2,1┃A2,2┃
┗━━┻━━┛
矩阵B也一样:
┏━━┳━━┓
┃B1,1┃B1,2┃
┣━━╋━━┫
┃B2,1┃B2,2┃
┗━━┻━━┛
本来2X2X2=8,如果对于这些子矩阵需要进行8次乘法,那就没有效率上的提升。但Herr Strassen设计了一种神奇的计算方法,使用7次子矩阵乘法,请看下面的公式:
1. M1=(A12-A22)(B21+B22)
2. M2=(A11+A22)(B11+B22)
3. M3=(A11-A21)(B11+B12)
4. M4=(A11+A12)B22
5. M5=A11(B12-B22)
6. M6=A22(B21-B11)
7. M7=(A21+A22)B11
这7个小矩阵M1~M7都是n/2大小的方阵。通过下面四个公式又可以将它们组合成最终结果C:
1. C11=M1+M2-M4+M6
2. C12=M4+M5
3. C21=M6+M7
4. C22=M2-M3+M5-M7
┏━━┳━━┓
┃C1,1┃C1,2┃
┣━━╋━━┫
┃C2,1┃C2,2┃
┗━━┻━━┛
七个式子是乘法,属于子问题的分解;四个式子是加减法,复杂度是严格O(n^2)的。两者的复杂度在下面会进行推导。
接下来开启高中模式,用数列知识来证明O(n^2.81)的复杂度(高考之前的我们应该是对数列问题最敏感的):
T(n)=7T(n/2)+O(n^2),前半部分表示7次子矩阵乘法,后半部分是4次加减法。
接下来把7T(n/2)这部分逐层拆成49T(n/4)、343T(n/8)...
T(n)=7(7T(n/4)+O(n^2/4))+O(n^2)
T(n)=49T(n/4)+O(n^2/4)+O(n^2)
...
T(n)=7^k * T(n/(2^k))+O((n/(2^k))^2)+...+O((n/2)^2)+O(n^2)
此处的k等于多少呢?等于log2n。为什么呢?因为2^k=n,这样拆了k次以后就不能再拆了,而k不一定要是整数。
再来看看后面的等比数列部分O(n^2)+O(n^2/4)+...+O(n^2/(4^k))。这部分的和S=O(n^2) * (4/3-1/(3 * 4^k))
由于2^k=n,于是4^k=n^2。那么S=O(n^2) * (4/3-1/(3 * n^2))=O(n^2)-O(1)=O(n^2)。
所以后面的等比数列总体的时间复杂度是O(n^2)级别的。
再来看看前面的部分7^k * T(n/(2^k))
T(n/(2^k))=O(1),表示这个公式展开了k次以后就不能再展开了。
7^k=7^(log2n)=7^(log7n/log72)=7^(log7n * log27)=(7^log7n)^log27=n^log27
又因为log27≈2.81,于是O(n^2.81)+O(n^2)=O(n^2.81)。
由此看来Strassen算法的O(n^2.81)的时间复杂度是由于化8次子问题为7次才得到的。
不过Srassen算法的局限性也是不少的:矩阵的尺寸限制的太严格,而且处理子问题的并行能力不太强。因此这个算法更多是出现在教材和ACM题库里,实际应用则没那么多。
实现:
1 // My implementation for Strassen Algorithm. Matrix size is limited to powers of 2 only. 2 #include <iostream> 3 #include <vector> 4 using namespace std; 5 6 void getSubmatrix(const vector<vector<int> > &m, vector<vector<int> > &sub, int part) 7 { 8 int n = (int)m.size(); 9 int i, j; 10 int n1 = n / 2; 11 12 int top = part / 2 * n1; 13 int left = part % 2 * n1; 14 for (i = 0; i < n1; ++i) { 15 for (j = 0; j < n1; ++j) { 16 sub[i][j] = m[top + i][left + j]; 17 } 18 } 19 } 20 21 void addMatrix(const vector<vector<int> > &a, const vector<vector<int> > &b, 22 vector<vector<int> > &c, int n) 23 { 24 int i, j; 25 26 for (i = 0; i < n; ++i) { 27 for (j = 0; j < n; ++j) { 28 c[i][j] = a[i][j] + b[i][j]; 29 } 30 } 31 } 32 33 void subtractMatrix(const vector<vector<int> > &a, const vector<vector<int> > &b, 34 vector<vector<int> > &c, int n) 35 { 36 int i, j; 37 38 for (i = 0; i < n; ++i) { 39 for (j = 0; j < n; ++j) { 40 c[i][j] = a[i][j] - b[i][j]; 41 } 42 } 43 } 44 45 void setSubmatrix(const vector<vector<int> > &sub, vector<vector<int> > &m, int part) 46 { 47 int n = (int)m.size(); 48 int i, j; 49 int n1 = n / 2; 50 51 int top = part / 2 * n1; 52 int left = part % 2 * n1; 53 for (i = 0; i < n1; ++i) { 54 for (j = 0; j < n1; ++j) { 55 m[top + i][left + j] = sub[i][j]; 56 } 57 } 58 } 59 60 void matrixMultiplicationRecursive(const vector<vector<int> > &a, 61 const vector<vector<int> > &b, vector<vector<int> > &c, int n) 62 { 63 if (n == 1) { 64 c[0][0] = a[0][0] * b[0][0]; 65 return; 66 } 67 68 int i; 69 int n1 = n / 2; 70 71 vector<vector<int> > aa[4]; 72 vector<vector<int> > bb[4]; 73 74 for (i = 0; i < 4; ++i) { 75 aa[i].resize(n1, vector<int>(n1)); 76 bb[i].resize(n1, vector<int>(n1)); 77 } 78 79 for (i = 0; i < 4; ++i) { 80 getSubmatrix(a, aa[i], i); 81 getSubmatrix(b, bb[i], i); 82 } 83 84 vector<vector<int> > x, y; 85 vector<vector<int> > m[7]; 86 87 x.resize(n1, vector<int>(n1)); 88 y.resize(n1, vector<int>(n1)); 89 for (i = 0; i < 7; ++i) { 90 m[i].resize(n1, vector<int>(n1)); 91 } 92 93 subtractMatrix(aa[1], aa[3], x, n1); 94 addMatrix(bb[2], bb[3], y, n1); 95 matrixMultiplicationRecursive(x, y, m[0], n1); 96 97 addMatrix(aa[0], aa[3], x, n1); 98 addMatrix(bb[0], bb[3], y, n1); 99 matrixMultiplicationRecursive(x, y, m[1], n1); 100 101 subtractMatrix(aa[0], aa[2], x, n1); 102 addMatrix(bb[0], bb[1], y, n1); 103 matrixMultiplicationRecursive(x, y, m[2], n1); 104 105 addMatrix(aa[0], aa[1], x, n1); 106 matrixMultiplicationRecursive(x, bb[3], m[3], n1); 107 108 subtractMatrix(bb[1], bb[3], y, n1); 109 matrixMultiplicationRecursive(aa[0], y, m[4], n1); 110 111 subtractMatrix(bb[2], bb[0], y, n1); 112 matrixMultiplicationRecursive(aa[3], y, m[5], n1); 113 114 addMatrix(aa[2], aa[3], x, n1); 115 matrixMultiplicationRecursive(x, bb[0], m[6], n1); 116 117 addMatrix(m[0], m[1], x, n1); 118 subtractMatrix(x, m[3], x, n1); 119 addMatrix(x, m[5], x, n1); 120 setSubmatrix(x, c, 0); 121 122 addMatrix(m[3], m[4], x, n1); 123 setSubmatrix(x, c, 1); 124 125 addMatrix(m[5], m[6], x, n1); 126 setSubmatrix(x, c, 2); 127 128 subtractMatrix(m[1], m[2], x, n1); 129 addMatrix(x, m[4], x, n1); 130 subtractMatrix(x, m[6], x, n1); 131 setSubmatrix(x, c, 3); 132 133 for (i = 0; i < 4; ++i) { 134 aa[i].clear(); 135 bb[i].clear(); 136 } 137 for (i = 0; i < 7; ++i) { 138 m[i].clear(); 139 } 140 x.clear(); 141 y.clear(); 142 } 143 144 void matrixMultiplication(const vector<vector<int> > &a, 145 const vector<vector<int> > &b, vector<vector<int> > &c) 146 { 147 int n = (int)a.size(); 148 149 matrixMultiplicationRecursive(a, b, c, n); 150 } 151 152 int main() 153 { 154 int n; 155 int i, j; 156 vector<vector<int> > a, b, c; 157 158 while (cin >> n && n > 0) { 159 a.resize(n, vector<int>(n)); 160 b.resize(n, vector<int>(n)); 161 c.resize(n, vector<int>(n)); 162 163 for (i = 0; i < n; ++i) { 164 for (j = 0; j < n; ++j) { 165 cin >> a[i][j]; 166 } 167 } 168 169 for (i = 0; i < n; ++i) { 170 for (j = 0; j < n; ++j) { 171 cin >> b[i][j]; 172 } 173 } 174 175 matrixMultiplication(a, b, c); 176 177 for (i = 0; i < n; ++i) { 178 for (j = 0; j < n; ++j) { 179 cout << c[i][j] << ' '; 180 } 181 cout << endl; 182 } 183 cout << endl; 184 185 a.clear(); 186 b.clear(); 187 c.clear(); 188 } 189 190 return 0; 191 }