• Conquer and Divide经典例子之Strassen算法解决大型矩阵的相乘


    通过汉诺塔问题理解递归的精髓中我讲解了怎么把一个复杂的问题一步步recursively划分了成简单显而易见的小问题。其实这个解决问题的思路就是算法中常用的divide and conquer, 这篇日志通过解决矩阵的乘法,来了解另外一个基本divide and conque思想的strassen算法。

    矩阵A乘以B等于X, 则Xij = 
    注意左乘右乘的区别,AB 与BA是不同的。
    如果r = 1, 直接就是两个数的相乘。
    如果r = 2, 例如
    X = 
    [ 1, 2; 
      3, 4];
    Y = 
    [ 2, 3;
     4, 5];
    R = XY的计算十分简单,但是如果r很大,耗时是O(r^3)。为了简化,可以把X, Y各自划分成2X2的矩阵,每一个元素其实是有n/2行的矩阵
    (注:这里仅讲解行数等于列数的情况。)

    X = 
    [A, B;
    C, D];

    Y = 
    [E, F;
    G, H]

    所以XY =[
    AE+BG, AF+BH;
    CE+DG, CF+DH]

    Strassen引入seven magic product 分别是P1, P2, P3 ,P4, P5, P6, P7
    P1 = A(F-H)
    P2 = (A+B)H
    P3 = (C+D)E
    P4 = D(G-E)
    P5 = (A+D)(E+H)
    P6 = (B-D)(G+H)
    P7 = (A-C)(E+F)

    这样XY = 
    [P5+P4-P2+P6, P1+P2;
    P3+P4, P1+P5-P3-P7]

    然后通过递归的策略计算矩阵的相乘,递归的出口是n = 1.

    关键点就是这些,附上代码吧。

    [java] view plaincopy在CODE上查看代码片派生到我的代码片
     
      1. //multiply matrix multiplication  
      2. import java.util.Scanner;  
      3. public class Strassen{  
      4.     public Strassen(){}  
      5.   
      6.   
      7.     /** split a parent matrix into child matrics8*/  
      8.     public static void split(int[][] P, int[][] C, int iB, int jB){  
      9.         for(int i1=0, i2 = iB; i1<C.length; i1++, i2++)  
      10.             for(int j1=0, j2=jB; j1<C.length; j1++, j2++)  
      11.                 C[i1][j1] = P[i2][j2];  
      12.     }  
      13.   
      14.   
      15.     /**join child matric into parent matrix*/  
      16.     public static void join(int[][] C, int[][] P, int iB, int jB){  
      17.         for(int i1=0, i2 = iB; i1<C.length; i1++, i2++)  
      18.             for(int j1=0, j2=jB; j1<C.length; j1++, j2++)  
      19.                 P[i2][j2]=C[i1][j1];   
      20.     }  
      21.   
      22.   
      23.     /**add two matrics into one*/  
      24.     public static int[][] add(int[][] A, int[][] B){  
      25.         //A and B has the same dimension  
      26.         int n = A.length;  
      27.         int[][] C = new int[n][n];  
      28.         for (int i=0; i<n; i++)  
      29.             for(int j=0; j<n; j++)  
      30.                 C[i][j] = A[i][j] + B[i][j];  
      31.                   
      32.         return C;          
      33.     }  
      34.   
      35.   
      36.   
      37.   
      38.     //subtract one matric by another  
      39.     public static int[][] sub(int[][] A, int[][] B){  
      40.         //A and B has the same dimension  
      41.         int n = A.length;  
      42.         int[][] C = new int[n][n];  
      43.         for (int i=0; i<n; i++)  
      44.             for(int j=0; j<n; j++)  
      45.                 C[i][j] = A[i][j] - B[i][j];  
      46.         return C;     
      47.     }  
      48.   
      49.   
      50.     //Multiply matrix  
      51.     public static int[][] multiply(int[][] A, int[][] B){  
      52.         int n = A.length;  
      53.         int[][] R = new int[n][n];  
      54.   
      55.   
      56.         /**exit*/  
      57.         if(n==1)  
      58.             R[0][0] = A[0][0]+B[0][0];  
      59.   
      60.   
      61.         else{  
      62.             //divide A into 4 submatrix  
      63.             int[][] A11 = new int[n/2][n/2];  
      64.             int[][] A12 = new int[n/2][n/2];  
      65.             int[][] A21 = new int[n/2][n/2];  
      66.             int[][] A22 = new int[n/2][n/2];  
      67.   
      68.   
      69.             split(A, A11, 00);  
      70.             split(A, A12, 0, n/2);  
      71.             split(A, A21, n/20);  
      72.             split(A, A22, n/2, n/2);  
      73.   
      74.   
      75.             //divide B into 4 submatric  
      76.             int[][] B11 = new int[n/2][n/2];  
      77.             int[][] B12 = new int[n/2][n/2];  
      78.             int[][] B21 = new int[n/2][n/2];  
      79.             int[][] B22 = new int[n/2][n/2];  
      80.   
      81.   
      82.             split(B, B11, 00);  
      83.             split(B, B12, 0, n/2);  
      84.             split(B, B21, n/20);  
      85.             split(B, B22, n/2, n/2);  
      86.   
      87.   
      88.             //seven magic products  
      89.             int[][] P1 = multiply(A11, sub(B12, B22));  
      90.             int[][] P2 = multiply(add(A11,A12), B22);  
      91.             int[][] P3 = multiply(add(A21, A22), B11);  
      92.             int[][] P4 = multiply(A22, sub(B21, B11));  
      93.             int[][] P5 = multiply(add(A11, A22), add(B11, B22));  
      94.             int[][] P6 = multiply(sub(A12, A22), add(B21, B22));  
      95.             int[][] P7 = multiply(sub(A11, A21), add(B11, B12));  
      96.   
      97.   
      98.   
      99.   
      100.             //new 4 submatrix  
      101.             int[][] R11 = add(add(P5, sub(P4, P2)), P6);  
      102.             int[][] R12 = add(P1, P2);  
      103.             int[][] R21 = add(P3, P4);  
      104.             int[][] R22 = sub(sub(add(P1, P5), P3), P7);  
      105.   
      106.   
      107.             //joint together  
      108.             join(R11, R, 00);  
      109.             join(R12, R, 0, n/2);  
      110.             join(R21, R, n/20);  
      111.             join(R22, R, n/2, n/2);  
      112.   
      113.         }  
      114.         return R;  
      115.     }  
      116.   
      117.   
      118.     //main   
      119.     public static void main(String[] args){  
      120.           
      121.         Scanner scan = new Scanner(System.in);  
      122.         System.out.println("Strassen Multiplication Algorithm Test\n");  
      123.         Strassen s = new Strassen();  
      124.    
      125.   
      126.   
      127.         System.out.println("Fetch the matric A and B...");  
      128.         int N = scan.nextInt();  
      129.         int[][] A = new int[N][N];  
      130.         int[][] B = new int[N][N];  
      131.   
      132.   
      133.         for (int i = 0; i < N; i++)  
      134.             for (int j = 0; j < N; j++)  
      135.                 A[i][j] = scan.nextInt();  
      136.   
      137.   
      138.         for (int i = 0; i < N; i++)  
      139.             for (int j = 0; j < N; j++)  
      140.                 B[i][j] = scan.nextInt();  
      141.   
      142.   
      143.         System.out.println("Fetch Completed!");  
      144.    
      145.         int[][] C = s.multiply(A, B);  
      146.           
      147.         System.out.println("\nmatrices A = ");  
      148.         for (int i = 0; i < N; i++){  
      149.             for (int j = 0; j < N; j++)  
      150.                 System.out.print(A[i][j] +" ");  
      151.             System.out.println();  
      152.         }  
      153.   
      154.   
      155.         System.out.println("\nmatrices B =");  
      156.         for (int i = 0; i < N; i++) {  
      157.             for (int j = 0; j < N; j++)  
      158.                 System.out.print(B[i][j] +" ");  
      159.             System.out.println();  
      160.         }  
      161.    
      162.         System.out.println("\nProduct of matrices A and  B  = ");  
      163.         for (int i = 0; i < N; i++)  
      164.         {  
      165.             for (int j = 0; j < N; j++)  
      166.                 System.out.print(C[i][j] +" ");  
      167.             System.out.println();  
      168.         }  
      169.     }  
      170. }  
  • 相关阅读:
    redis修改密码
    redis配置
    django中日志配置
    django中缓存配置
    navicat批量导入数据
    django添加REST_FRAMEWORK 接口浏览
    django验证码配置与使用
    LUA_OBJECT
    LUA comment
    lua-redis
  • 原文地址:https://www.cnblogs.com/guanghuiz/p/3746886.html
Copyright © 2020-2023  润新知