# 原理：普通分块矩阵乘法需要进行 8 次子矩阵乘法，时间复杂度为 $\Theta(n^3) = \Theta(n^{\lg 8})$；
# Strassen 算法改进后仅需要 7 次子矩阵乘法，时间复杂度为 $\Theta(n^{\lg 7}) \approx \Theta(n^{2.81})$。
# 具体推导见《算法导论》中利用主定理（master theorem）推导时间复杂度的部分。
def matrix_divide(A):
    """Split a square matrix into its four equally sized quadrants.

    With h = len(A) // 2, returns the tuple
    (top-left, top-right, bottom-left, bottom-right), each a freshly
    allocated h x h list of lists. When len(A) is odd, the last row and
    the last column do not appear in any quadrant.
    """
    half = len(A) // 2

    def quadrant(row_off, col_off):
        # Copy the half x half sub-block whose top-left corner is (row_off, col_off).
        return [[A[row_off + r][col_off + c] for c in range(half)]
                for r in range(half)]

    return (
        quadrant(0, 0),
        quadrant(0, half),
        quadrant(half, 0),
        quadrant(half, half),
    )
def matrix_add(A, B):
    """Return the element-wise sum A + B as a new matrix.

    A and B are lists of rows with the same shape. Any rectangular shape is
    supported: the column count is taken from each row of A rather than
    assumed to equal the number of rows. Indexing B by A's shape means an
    undersized B raises IndexError instead of being silently truncated.
    """
    return [[a + B[i][j] for j, a in enumerate(row)]
            for i, row in enumerate(A)]
def matrix_sub(A, B):
    """Return the element-wise difference A - B as a new matrix.

    A and B are lists of rows with the same shape. Any rectangular shape is
    supported: the column count is taken from each row of A rather than
    assumed to equal the number of rows. Indexing B by A's shape means an
    undersized B raises IndexError instead of being silently truncated.
    """
    return [[a - B[i][j] for j, a in enumerate(row)]
            for i, row in enumerate(A)]
def matrix_merge(C11, C12, C21, C22):
    """Assemble four h x h quadrants into a single 2h x 2h matrix.

    h is taken from len(C11). The result is a new list of lists laid out as
        [[C11, C12],
         [C21, C22]]
    """
    h = len(C11)
    top = [[C11[r][c] for c in range(h)] + [C12[r][c] for c in range(h)]
           for r in range(h)]
    bottom = [[C21[r][c] for c in range(h)] + [C22[r][c] for c in range(h)]
              for r in range(h)]
    return top + bottom
def _pad_square(M, size):
    """Return a copy of matrix M zero-padded on the right/bottom to size x size."""
    padded = [list(row) + [0] * (size - len(row)) for row in M]
    padded.extend([0] * size for _ in range(size - len(M)))
    return padded


def strassen(A, B):
    """Multiply two n x n matrices with Strassen's algorithm.

    Each level splits both operands into quadrants and forms the product
    from 7 recursive multiplications (P1..P7) instead of the 8 used by
    plain block multiplication (formulas as in CLRS section 4.2).

    Args:
        A: n x n matrix given as a list of rows.
        B: n x n matrix given as a list of rows (same n as A).

    Returns:
        The n x n product A*B as a new list of lists; an empty A yields [].

    Any n is accepted. When n is not a power of two the operands are
    zero-padded once to the next power of two and the product is trimmed
    back to n x n, because halving an odd size would silently drop the
    last row and column.
    """
    n = len(A)
    if n == 0:
        # Without this guard, splitting [] would recurse forever.
        return []
    if n == 1:
        return [[A[0][0] * B[0][0]]]
    if n & (n - 1):
        # n is not a power of two: pad so every later split is exact.
        size = 1 << (n - 1).bit_length()
        product = strassen(_pad_square(A, size), _pad_square(B, size))
        return [row[:n] for row in product[:n]]

    A11, A12, A21, A22 = matrix_divide(A)
    B11, B12, B21, B22 = matrix_divide(B)

    # The ten sums/differences S1..S10.
    S1 = matrix_sub(B12, B22)
    S2 = matrix_add(A11, A12)
    S3 = matrix_add(A21, A22)
    S4 = matrix_sub(B21, B11)
    S5 = matrix_add(A11, A22)
    S6 = matrix_add(B11, B22)
    S7 = matrix_sub(A12, A22)
    S8 = matrix_add(B21, B22)
    S9 = matrix_sub(A11, A21)
    S10 = matrix_add(B11, B12)

    # The seven recursive products P1..P7.
    P1 = strassen(A11, S1)
    P2 = strassen(S2, B22)
    P3 = strassen(S3, B11)
    P4 = strassen(A22, S4)
    P5 = strassen(S5, S6)
    P6 = strassen(S7, S8)
    P7 = strassen(S9, S10)

    # Recombine into the quadrants of the product.
    C11 = matrix_add(P5, matrix_sub(P4, matrix_sub(P2, P6)))  # P5 + P4 - P2 + P6
    C12 = matrix_add(P1, P2)
    C21 = matrix_add(P3, P4)
    C22 = matrix_add(P5, matrix_sub(P1, matrix_add(P3, P7)))  # P5 + P1 - P3 - P7
    return matrix_merge(C11, C12, C21, C22)
def main():
    """Demo: multiply two fixed 4 x 4 matrices with strassen and print the product."""
    # Row k of `left` is filled with k (1..4); row k of `right` with k + 4 (5..8).
    left = [[k] * 4 for k in range(1, 5)]
    right = [[k] * 4 for k in range(5, 9)]
    print(strassen(left, right))
if __name__ == '__main__':
    # Run the demo only when executed as a script, not when imported.
    main()