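Problem: compute the matrix product C = A * B, where A, B, and C are all N x N square matrices and N is a power of 2 (to simplify the problem). Partition each matrix into four N/2 x N/2 blocks:
    A = [[A11, A12], [A21, A22]]
    B = [[B11, B12], [B21, B22]]
    C = [[C11, C12], [C21, C22]]
Then, by the rules of matrix multiplication:
    C11 = A11 * B11 + A12 * B21
    C12 = A11 * B12 + A12 * B22
    C21 = A21 * B11 + A22 * B21
    C22 = A21 * B12 + A22 * B22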
The conventional method for N x N square matrices:

    def squre_matrix_multiply(A, B):
        n = len(A)
        # let c be a new n x n matrix
        c = [[0 for y in range(n)] for x in range(n)]
        # c[i][j] is the dot product of row i of A and column j of B
        for i in range(n):
            for j in range(n):
                for k in range(n):
                    c[i][j] = c[i][j] + A[i][k] * B[k][j]
        print(c)

    if __name__ == '__main__':
        A = [[2,1],[3,6]]
        B = [[3,4],[2,2]]
        squre_matrix_multiply(A,B)

Result:
[[8, 10], [21, 24]]
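As a sanity check on the block formulas from the problem statement, the sketch below multiplies two 4 x 4 matrices the conventional way and compares each block of the result with the corresponding sum of block products. The helper names `matmul`, `mat_add`, `block`, and `block_formulas_hold` are made up for this illustration, and the 4 x 4 inputs are arbitrary.

```python
def matmul(X, Y):
    """Conventional triple-loop product of two matrices given as lists of rows."""
    rows, inner, cols = len(X), len(Y), len(Y[0])
    return [[sum(X[i][k] * Y[k][j] for k in range(inner)) for j in range(cols)]
            for i in range(rows)]


def mat_add(X, Y):
    """Element-wise sum of two equally sized matrices."""
    return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


def block(M, r, c):
    """Return block M_rc (r, c in {1, 2}) of a square matrix of even size."""
    h = len(M) // 2
    return [row[(c - 1) * h: c * h] for row in M[(r - 1) * h: r * h]]


def block_formulas_hold(A, B):
    """Check C_rc == A_r1 * B_1c + A_r2 * B_2c for all four blocks of C = A * B."""
    C = matmul(A, B)
    for r in (1, 2):
        for c in (1, 2):
            expected = mat_add(matmul(block(A, r, 1), block(B, 1, c)),
                               matmul(block(A, r, 2), block(B, 2, c)))
            if block(C, r, c) != expected:
                return False
    return True


if __name__ == '__main__':
    A = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
    B = [[2, 0, 1, 3], [1, 1, 0, 2], [4, 2, 1, 0], [0, 3, 2, 1]]
    print(block_formulas_hold(A, B))  # True
```

The same check passes on the 2 x 2 example above, where every block is a single number.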
Solving with divide and conquer:
Divide and conquer: partition each N x N matrix into four N/2 x N/2 blocks, compute the eight block products recursively, and add them in pairs.

    def squre_matrix_multiply_recursive(A, B):
        n = len(A)
        # base case: 1 x 1 matrices are multiplied as scalars
        if n == 1:
            return [[A[0][0] * B[0][0]]]
        # partition A and B into four n/2 x n/2 blocks each
        A11, A12, A21, A22 = partition(A)
        B11, B12, B21, B22 = partition(B)
        mul = squre_matrix_multiply_recursive
        # eight recursive products, combined by the block formulas
        C11 = add(mul(A11, B11), mul(A12, B21))
        C12 = add(mul(A11, B12), mul(A12, B22))
        C21 = add(mul(A21, B11), mul(A22, B21))
        C22 = add(mul(A21, B12), mul(A22, B22))
        # stitch the four blocks back into one n x n matrix
        top = [left + right for left, right in zip(C11, C12)]
        bottom = [left + right for left, right in zip(C21, C22)]
        return top + bottom

    def partition(M):
        # split M into its blocks M11, M12, M21, M22
        h = len(M) // 2
        return ([row[:h] for row in M[:h]], [row[h:] for row in M[:h]],
                [row[:h] for row in M[h:]], [row[h:] for row in M[h:]])

    def add(X, Y):
        # element-wise sum of two equally sized matrices
        return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]

    if __name__ == '__main__':
        A = [[2,1],[3,6]]
        B = [[3,4],[2,2]]
        print(squre_matrix_multiply_recursive(A,B))

Result:
[[8, 10], [21, 24]]
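
To see what the eight recursive products cost, here is a small sketch that counts scalar multiplications only (additions are ignored). It assumes N is a power of 2 and that the recursion bottoms out at 1 x 1 products; the function name `count_scalar_multiplications` and its parameter `products_per_level` are made up for this illustration.

```python
def count_scalar_multiplications(n, products_per_level):
    """Count scalar multiplications for an n x n product (n a power of 2)
    when every level recurses on `products_per_level` half-size products."""
    if n == 1:
        return 1  # a 1 x 1 product is a single scalar multiplication
    return products_per_level * count_scalar_multiplications(n // 2, products_per_level)


if __name__ == '__main__':
    # with 8 products per level the count equals n ** 3,
    # the same as the conventional triple loop
    for n in (2, 4, 8):
        print(n, count_scalar_multiplications(n, 8), n ** 3)
```

With 8 products per level the divide-and-conquer version therefore saves no multiplications over the triple loop. Passing 7 as the second argument gives 7, 49, and 343 for the same three sizes.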
Strassen's algorithm:
Strassen's algorithm performs only 7 recursive multiplications of N/2 x N/2 matrices (the plain divide-and-conquer algorithm performs 8).
1. Create 10 N/2 x N/2 matrices S1, S2, …, S10.
    S1 = B12 - B22
    S2 = A11 + A12
    S3 = A21 + A22
    S4 = B21 - B11
    S5 = A11 + A22
    S6 = B11 + B22
    S7 = A12 - A22
    S8 = B21 + B22
    S9 = A11 - A21
    S10 = B11 + B12
2. Build P1, P2, …, P7 from S1 … S10.
    P1 = A11 * S1 = A11 * B12 - A11 * B22
    P2 = S2 * B22 = A11 * B22 + A12 * B22
    P3 = S3 * B11 = A21 * B11 + A22 * B11
    P4 = A22 * S4 = A22 * B21 - A22 * B11
    P5 = S5 * S6 = A11 * B11 + A11 * B22 + A22 * B11 + A22 * B22
    P6 = S7 * S8 = A12 * B21 + A12 * B22 - A22 * B21 - A22 * B22
    P7 = S9 * S10 = A11 * B11 + A11 * B12 - A21 * B11 - A21 * B12
3. Compute C from the P1 … P7 built in the previous step.
    C11 = P5 + P4 - P2 + P6
    C12 = P1 + P2
    C21 = P3 + P4
    C22 = P5 + P1 - P3 - P7

    def strassn(A, B):
        n = len(A)
        # base case: 1 x 1 matrices are multiplied as scalars
        if n == 1:
            return [[A[0][0] * B[0][0]]]
        # let c be a new n x n matrix
        c = [[0 for x in range(n)] for y in range(n)]
        # only suitable for 2 x 2 matrices: every block is a single number
        # step 1
        s1 = B[0][1] - B[1][1]
        s2 = A[0][0] + A[0][1]
        s3 = A[1][0] + A[1][1]
        s4 = B[1][0] - B[0][0]
        s5 = A[0][0] + A[1][1]
        s6 = B[0][0] + B[1][1]
        s7 = A[0][1] - A[1][1]
        s8 = B[1][0] + B[1][1]
        s9 = A[0][0] - A[1][0]
        s10 = B[0][0] + B[0][1]
        # step 2
        p1 = A[0][0] * s1
        p2 = s2 * B[1][1]
        p3 = s3 * B[0][0]
        p4 = A[1][1] * s4
        p5 = s5 * s6
        p6 = s7 * s8
        p7 = s9 * s10
        # step 3
        c[0][0] = p5 + p4 - p2 + p6
        c[0][1] = p1 + p2
        c[1][0] = p3 + p4
        c[1][1] = p5 + p1 - p3 - p7
        return c

    if __name__ == '__main__':
        A = [[2,1],[3,6]]
        B = [[3,4],[2,2]]
        print(strassn(A, B))

Result:
[[8, 10], [21, 24]]
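
The strassn listing above treats every block as a single number, so it only covers the 2 x 2 case. Below is a minimal sketch of how the same S and P steps could be applied recursively to any N x N input with N a power of 2. The names `strassen_recursive`, `split`, `mat_add`, and `mat_sub` are made up for this illustration, and no padding is done for other sizes.

```python
def mat_add(X, Y):
    """Element-wise sum of two equally sized matrices."""
    return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


def mat_sub(X, Y):
    """Element-wise difference of two equally sized matrices."""
    return [[x - y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


def split(M):
    """Split a square matrix of even size into blocks M11, M12, M21, M22."""
    h = len(M) // 2
    return ([row[:h] for row in M[:h]], [row[h:] for row in M[:h]],
            [row[:h] for row in M[h:]], [row[h:] for row in M[h:]])


def strassen_recursive(A, B):
    """Multiply two n x n matrices (n a power of 2) with 7 recursive products."""
    n = len(A)
    if n == 1:
        return [[A[0][0] * B[0][0]]]
    A11, A12, A21, A22 = split(A)
    B11, B12, B21, B22 = split(B)
    # steps 1 and 2: the sums S1..S10 are formed inline as arguments of P1..P7
    P1 = strassen_recursive(A11, mat_sub(B12, B22))
    P2 = strassen_recursive(mat_add(A11, A12), B22)
    P3 = strassen_recursive(mat_add(A21, A22), B11)
    P4 = strassen_recursive(A22, mat_sub(B21, B11))
    P5 = strassen_recursive(mat_add(A11, A22), mat_add(B11, B22))
    P6 = strassen_recursive(mat_sub(A12, A22), mat_add(B21, B22))
    P7 = strassen_recursive(mat_sub(A11, A21), mat_add(B11, B12))
    # step 3: combine the seven products into the four blocks of C
    C11 = mat_add(mat_sub(mat_add(P5, P4), P2), P6)
    C12 = mat_add(P1, P2)
    C21 = mat_add(P3, P4)
    C22 = mat_sub(mat_sub(mat_add(P5, P1), P3), P7)
    # stitch the blocks back into one n x n matrix
    top = [left + right for left, right in zip(C11, C12)]
    bottom = [left + right for left, right in zip(C21, C22)]
    return top + bottom


if __name__ == '__main__':
    A = [[2, 1], [3, 6]]
    B = [[3, 4], [2, 2]]
    print(strassen_recursive(A, B))  # [[8, 10], [21, 24]]
```

Recursing all the way down to 1 x 1 blocks keeps the sketch short. Switching to the conventional triple loop below some block size is a common variation, but that is a tuning choice outside the scope of this sketch.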
Reference:
1. Introduction to Algorithms