• Algorithms


    问题:
        求解矩阵乘法  C = A * B, 已知 A, B, C 均为 N x N 的方阵, 切 N 为 2 的幂(为简化问题). 
            A = [[A11, A12], [A21, A22]]
            B = [[B11, B12], [B21, B22]]
            C = [[C11, C12], [C21, C22]]
            
            则(矩阵乘法运算法则):
                C11 = A11 * B11 + A12 * B21
                C12 = A11 * B12 + A12 * B22
                C21 = A21 * B11 + A22 * B21
                C22 = A21 * B12 + A22 * B22
    N x N 方阵的常规计算方法:
    
    def squre_matrix_multiply(A, B):
        n = len(A)
        # let c to be a new n x n matrix
        c = [[0 for y in range(n)] for x in range(n)]
        for i in range(n):
            for j in range(n):
                for k in range(n):
                    c[i][j] = c[i][j] + A[i][k] * B[k][j]
    
        print(c)
    
    if __name__ == '__main__':
        A = [[2,1],[3,6]]
        B = [[3,4],[2,2]]
        squre_matrix_multiply(A,B)
    
    结果:
        [[8, 10], [21,24]]
    通过分治思想求解:
    分治思想: 将 N x N 划分为
    4 个 N/2 * N/2 的子矩阵乘积之和.    def squre_matrix_multiply_recursive(A, B): try: n = len(A[0]) except TypeError: n = 1 # let c to be a new nxn matrix c = [[0 for x in range(n)] for y in range(n)] if n == 1: c = [[0],[0]] c[0][0] = A[0] * B[0] else: # partition A, B and C c[0][0] = squre_matrix_multiply_recursive([A[0][0]], [B[0][0]]) + squre_matrix_multiply_recursive([A[0][1]], [B[1][0]]) c[0][1] = squre_matrix_multiply_recursive([A[0][0]], [B[0][1]]) + squre_matrix_multiply_recursive([A[0][1]], [B[1][1]]) c[1][0] = squre_matrix_multiply_recursive([A[1][0]], [B[0][0]]) + squre_matrix_multiply_recursive([A[1][1]], [B[1][0]]) c[1][1] = squre_matrix_multiply_recursive([A[1][0]], [B[0][1]]) + squre_matrix_multiply_recursive([A[1][1]], [B[1][1]]) # process the res res = [[0 for x in range(n)] for y in range(n)] for i in range(n): for j in range(n): res[i][j] = sum_list(c[i][j]) return res def sum_list(A): # A: [[6], [0], [2], [0]] res = 0 try: for i in A: res += i[0] except TypeError: res += A return res

    if __name__ == '__main__':
        A = [[2,1],[3,6]]
        B = [[3,4],[2,2]]
    
        print(squre_matrix_multiply_recursive(A,B))
    结果: 
      [[
    8, 10], [21, 24]]
    Strassen 算法:
        Strassen 算法只递归进行 7 次运算 N/2 x N/2 矩阵的乘法(分治算法递归运算8次) . 
        
        1.  创建10个 N/2 x N/2 的矩阵 S1, S2, …, S10.
            S1 = B12 - B22
            S2 = A11 + A12
            S3 = A21 + A22
            S4 = B21 - B11
            S5 = A11 + A22
            S6 = B11 + B22
            S7 = A12 - A22
            S8 = B21 + B22
            S9 = A11 - A21
            S10 = B11 - B12
            
        2.  通过 S1 … S10 构建 P1, P2, …, P7
            P1 = A11 * S1 = A11 * B12 - A11 * B22
            P2 = S2 * B22 = A11 * B22 + A12 * B22
            P3 = S3 * B11 = A21 * B11 + A22 * B1
            P4 = A22 * S4 = A22 * B21 - A22 * B11
            P5 = S5 * S6 = A11 * B11 + A11 * B22 + A22 * B11 + A22 * B22
            P6 = S7 * S8 = A12 * B21 + A12 * B22 - A22 * B21 - A22 * B22
            P7 = S9 * 10 = A11 * B11 + A11 * B12 - A21 * B11 - A21 * B12
            
        3.  通过上面步骤构建的 P1 … P7 来计算 C
            C11 = P4 + P5 + P6 - P2
            C12 = P1 + P2
            C21 = P3 + P4
            C22 = P1 + P5 +P7 - P3
    
    def strassn(A, B):
        try:
            n = len(A[0])
        except TypeError:
            n = 1
        # let c to be a new nxn matrix
        c = [[0 for x in range(n)] for y in range(n)]
        if n == 1:
            c[0][0] = A[0] * B[0]
    
        # partition A, B and C        
        else:
            # only suit for 2X2 matrix
            # step 1
            s1 = B[0][1] - B[1][1]      
            s2 = A[0][0] + A[0][1]
            s3 = A[1][0] + A[1][1]
            s4 = B[1][0] - B[0][0]
            s5 = A[0][0] + A[1][1]
            s6 = B[0][0] + B[1][1]
            s7 = A[0][1] - A[1][1]
            s8 = B[1][0] + B[1][1]
            s9 = A[0][0] - A[1][0]
            s10 = B[0][0] + B[0][1]
    
            # step 2
            p1 = A[0][0] * s1
            p2 = s2 * B[1][1]
            p3 = s3 * B[0][0]
            p4 = A[1][1] * s4
            p5 = s5 * s6
            p6 = s7 * s8
            p7 = s9 * s10
    
            # step 3
            c[0][0] = p5 + p4 - p2 + p6
            c[0][1] = p1 + p2
            c[1][0] = p3 + p4
            c[1][1] = p5 + p1 - p3 - p7
    
        return c
    
    if __name__ == '__main__':
        A = [[2,1],[3,6]]
        B = [[3,4],[2,2]]
    
        print(strassn(A, B))
    结果:
        [[8, 10], [21, 24]]

    Reference, 

      1. Introduction to algorithms

    strassn(A, B)
  • 相关阅读:
    Leetcode 233 Number of Digit One
    获取各种常见形状的位图
    关于编程
    LintCode-第k大元素
    基于IBM Bluemix的数据缓存应用实例
    LeakCanary:简单粗暴的内存泄漏检測工具
    MFC,C++,VC++,VS2010 之间究竟是什么关系
    我对高考考场制度(比方是否同意迟到、忘带考证、上厕所)优化的点滴思考,不一定非常有道理
    ural 1989(树状数组+多项式hash)
    TI C66x DSP 系统events及其应用
  • 原文地址:https://www.cnblogs.com/zzyzz/p/12862998.html
Copyright © 2020-2023  润新知