• Strassen矩阵乘法


    有用资料

    下面这个链接对Strassen算法的原理讲解得更加透彻:

    http://www.ituring.com.cn/article/17978

    矩阵乘法的定义

    矩阵乘法,A*B=C,其中:

    这里写图片描述这里写图片描述

    那么乘法的定义呢?C的第i行第j列元素,等于A矩阵的第i行与B矩阵的第j列对应元素相乘后求和(即两者的点积)。用图形表示是最直观的,其定义就如下图:

    这里写图片描述

    基本矩阵乘法

    那么由上面图中的公式,我们很容易得到基本矩阵相乘的伪代码:

    // A、B、C 均为按行存储在一维数组中的 n×n 矩阵,下标从 0 开始
    for i = 0 to n-1            // C 的行
        for j = 0 to n-1        // C 的列
            tmp = 0
            for k = 0 to n-1
                tmp += A[i*n + k] * B[k*n + j]
            end for
            C[i*n + j] = tmp
        end for
    end for

    矩阵乘法的改进

    怎样对矩阵相乘的算法进行改进呢???一个想当然的想法:分块矩阵相乘!!!

    这里写图片描述

    这里写图片描述

    一共有8个(n/2)×(n/2)的矩阵乘法和4个(n/2)×(n/2)的矩阵加法。再次使用以前的Master Method,

    T(n) = 8T(n/2) + Θ(n^2)

    这里写图片描述

    由此可见,算法的时间复杂度并没有下降,怎么办呢??

    下面就到了伟大的Strassen’s Idea了。谁也不知道他是怎么想出来这个算法的,但是呢,一个指导思想是,要想降低算法的时间复杂度,就要设法降低乘法的次数,这位Strassen做到了,将8次乘法减少到7次!

    Strassen矩阵乘法

    这里写图片描述

    将乘法从八次减少到了7次,这个差值1看起来不起眼,但它出现在递归的每一层,最终改变的是时间复杂度的指数:T(n) = 7T(n/2) + Θ(n^2) = Θ(n^(log2 7)) ≈ Θ(n^2.81),见下式:

    这里写图片描述

    那么我们再来直观地看看这个一次乘法的减少在理论上的性能提升:

    这里写图片描述

    由此可见,其在性能上的提升是有多么巨大,在MIT算法导论课件上说,这个算法在n>30时会显示出效果,但是这要跟编程方法有关的,算法好不等于实现性能好。

    算法的思路是:将矩阵不断地分块,直到子矩阵的大小为2*2时直接计算并返回。需要注意的是,计算每一个P(即下面代码中的M1~M7)时的乘法本身也是小矩阵的相乘,同样要用Strassen方法,所以这是一个递归算法。

    Strassen 矩阵乘法实现

    完整代码一:数组表示

    #include <iostream>  
    using namespace std;   
    
    const int N = 6;     // Compile-time dimension of every matrix buffer in this program (all are declared T[N][N])
    
     //input 
    /**
     * Reads an n x n matrix from standard input, one row at a time.
     *
     * Before each row a "Please Input Line <k>" prompt (k is 1-based) is written
     * to standard output, then n values are extracted from cin into that row.
     *
     * @param n  number of rows and columns to read
     * @param p  destination buffer whose rows are N elements wide
     */
    template<typename T>
    void input(int n, T p[][N]) {
        int line = 0;
        while (line < n) {
            cout << "Please Input Line " << line + 1 << endl;
            T* cells = p[line];
            for (int col = 0; col < n; ++col) {
                cin >> cells[col];
            }
            ++line;
        }
    }
    
    /**
     * Prints the leading n x n part of C to standard output.
     *
     * A header line is written first, then one text line per matrix row with the
     * elements of that row separated by single spaces. (The previous version
     * ended the line after every element, so the row structure was lost.)
     *
     * @param n  number of rows and columns to print
     * @param C  matrix buffer whose rows are N elements wide
     */
    template<typename T>
    void output(int n, T C[][N]) {
        cout << "The Output Matrix is :" << endl;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (j > 0) {
                    cout << ' ';
                }
                cout << C[i][j];
            }
            cout << '\n';   // plain newline: flush once at the end instead of per element
        }
        cout.flush();
    }
    
    /**
     * Conventional (triple-loop) matrix product: C = A * B.
     *
     * Only the leading n x n part of each buffer is read or written. The size
     * defaults to 2, which is the fixed size this function always used before,
     * so existing three-argument calls behave exactly as they did.
     *
     * C must not be the same buffer as A or B: C is written while A and B are
     * still being read.
     *
     * @param A  left operand, rows are N elements wide
     * @param B  right operand, rows are N elements wide
     * @param C  receives the product
     * @param n  order of the matrices to multiply (default 2)
     */
    template<typename T>
    void Matrix_Multiply(T A[][N], T B[][N], T C[][N], int n = 2) {  //Calculating A*B->C
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                T sum = 0;   // accumulate locally, store once per element
                for (int t = 0; t < n; t++) {
                    sum = sum + A[i][t] * B[t][j];
                }
                C[i][j] = sum;
            }
        }
    }
    
    /**
     * Element-wise sum of two matrices: Z = X + Y.
     *
     * Only the leading n x n part of each buffer is touched.
     *
     * @param n  number of rows and columns to process
     * @param X  first addend
     * @param Y  second addend
     * @param Z  receives the sum
     */
    template <typename T>
    void Matrix_Add(int n, T X[][N], T Y[][N], T Z[][N]) {
        for (int row = 0; row < n; ++row) {
            const T* lhs = X[row];
            const T* rhs = Y[row];
            T* sum = Z[row];
            for (int col = 0; col < n; ++col) {
                sum[col] = lhs[col] + rhs[col];
            }
        }
    }
    
    /**
     * Element-wise difference of two matrices: Z = X - Y.
     *
     * Only the leading n x n part of each buffer is touched.
     *
     * @param n  number of rows and columns to process
     * @param X  minuend
     * @param Y  subtrahend
     * @param Z  receives the difference
     */
    template <typename T>
    void Matrix_Sub(int n, T X[][N], T Y[][N], T Z[][N]) {
        for (int row = 0; row < n; ++row) {
            const T* lhs = X[row];
            const T* rhs = Y[row];
            T* diff = Z[row];
            for (int col = 0; col < n; ++col) {
                diff[col] = lhs[col] - rhs[col];
            }
        }
    }
    
    
    /**
     * Computes C = A * B for the leading n x n part of the buffers using
     * Strassen's scheme (seven recursive half-size products instead of eight).
     *
     * Base cases:
     *   - n == 2 is delegated to Matrix_Multiply.
     *   - n < 2 or odd n is multiplied directly with a triple loop. An odd size
     *     cannot be split into four equal quadrants; previously such sizes kept
     *     recursing (e.g. 6 -> 3 -> 1 -> 0 -> 0 ...) and never terminated.
     *
     * Every recursion level that splits places 25 N x N temporaries on the stack.
     * n must not exceed N, because all buffers are N elements wide.
     * C must not be the same buffer as A or B.
     *
     * @param n  order of the matrices to multiply
     * @param A  left operand
     * @param B  right operand
     * @param C  receives the product
     */
    template <typename T>
    void Strassen(int n, T A[][N], T B[][N], T C[][N]) {
        if (n == 2) {  //2-order
            Matrix_Multiply(A, B, C);
            return;
        }

        if (n < 2 || n % 2 != 0) {
            // Not evenly divisible into quadrants: fall back to the direct product.
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < n; j++) {
                    T sum = 0;
                    for (int t = 0; t < n; t++) {
                        sum = sum + A[i][t] * B[t][j];
                    }
                    C[i][j] = sum;
                }
            }
            return;
        }

        const int half = n / 2;

        T A11[N][N], A12[N][N], A21[N][N], A22[N][N];
        T B11[N][N], B12[N][N], B21[N][N], B22[N][N];
        T C11[N][N], C12[N][N], C21[N][N], C22[N][N];
        T M1[N][N], M2[N][N], M3[N][N], M4[N][N], M5[N][N], M6[N][N], M7[N][N];
        T AA[N][N], BB[N][N];   // scratch operands, reused for every sum/difference

        // Divide: split A and B into four half x half quadrants each.
        for (int i = 0; i < half; i++) {
            for (int j = 0; j < half; j++) {
                A11[i][j] = A[i][j];
                A12[i][j] = A[i][j + half];
                A21[i][j] = A[i + half][j];
                A22[i][j] = A[i + half][j + half];

                B11[i][j] = B[i][j];
                B12[i][j] = B[i][j + half];
                B21[i][j] = B[i + half][j];
                B22[i][j] = B[i + half][j + half];
            }
        }

        // M1 = (A11 + A22) * (B11 + B22)
        Matrix_Add(half, A11, A22, AA);
        Matrix_Add(half, B11, B22, BB);
        Strassen(half, AA, BB, M1);

        // M2 = (A21 + A22) * B11
        Matrix_Add(half, A21, A22, AA);
        Strassen(half, AA, B11, M2);

        // M3 = A11 * (B12 - B22)
        Matrix_Sub(half, B12, B22, BB);
        Strassen(half, A11, BB, M3);

        // M4 = A22 * (B21 - B11)
        Matrix_Sub(half, B21, B11, BB);
        Strassen(half, A22, BB, M4);

        // M5 = (A11 + A12) * B22
        Matrix_Add(half, A11, A12, AA);
        Strassen(half, AA, B22, M5);

        // M6 = (A21 - A11) * (B11 + B12)
        Matrix_Sub(half, A21, A11, AA);
        Matrix_Add(half, B11, B12, BB);
        Strassen(half, AA, BB, M6);

        // M7 = (A12 - A22) * (B21 + B22)
        Matrix_Sub(half, A12, A22, AA);
        Matrix_Add(half, B21, B22, BB);
        Strassen(half, AA, BB, M7);

        // C11 = M1 + M4 - M5 + M7
        Matrix_Add(half, M1, M4, AA);
        Matrix_Sub(half, M7, M5, BB);
        Matrix_Add(half, AA, BB, C11);

        // C12 = M3 + M5
        Matrix_Add(half, M3, M5, C12);

        // C21 = M2 + M4
        Matrix_Add(half, M2, M4, C21);

        // C22 = M1 - M2 + M3 + M6
        Matrix_Sub(half, M1, M2, AA);
        Matrix_Add(half, M3, M6, BB);
        Matrix_Add(half, AA, BB, C22);

        // Combine: copy the four result quadrants back into C.
        for (int i = 0; i < half; i++) {
            for (int j = 0; j < half; j++) {
                C[i][j] = C11[i][j];
                C[i][j + half] = C12[i][j];
                C[i + half][j] = C21[i][j];
                C[i + half][j + half] = C22[i][j];
            }
        }
    }
    
    /**
     * Demo driver: fills two N x N integer matrices with test data, multiplies
     * them with Strassen() and prints the product.
     */
    int main()
    {
        // Operands A and B plus the product C.
        int A[N][N], B[N][N], C[N][N];

        // Arbitrary test data: every entry is row index times column index,
        // and B is an exact copy of A.
        for (int row = 0; row < N; ++row) {
            for (int col = 0; col < N; ++col) {
                const int value = row * col;
                A[row][col] = value;
                B[row][col] = value;
            }
        }

        // C = A * B
        Strassen(N, A, B, C);

        // Show the result.
        output(N, C);

        system("pause");
        return 0;
    }
    
    

    完整代码二:指针表示

    /*
    Strassen Algorithm Implementation in C++
    Coded By: Seyyed Hossein Hasan Pour MatiKolaee in May 5 2010 .
    Mazandaran University of Science and Technology,Babol,Mazandaran,Iran
    --------------------------------------------
    Email : Master.huricane@gmail.com
    YM    : Deathmaster_nemessis@yahoo.com
    Updated may 09 2010.
    */
    #include <iostream>
    #include <cstdlib>
    #include <iomanip>
    #include <ctime>
    #include <windows.h>
    using namespace std;
    
    int Strassen(int n, int** MatrixA, int ** MatrixB, int ** MatrixC);// Multiplies two n x n matrices recursively (Strassen) and stores the product in MatrixC.
    int ADD(int** MatrixA, int** MatrixB, int** MatrixResult, int length );// Adds two length x length matrices element-wise and places the result in MatrixResult.
    int SUB(int** MatrixA, int** MatrixB, int** MatrixResult, int length );// Subtracts MatrixB from MatrixA element-wise and places the result in MatrixResult.
    int MUL(int** MatrixA, int** MatrixB, int** MatrixResult, int length );// Multiplies two matrices the conventional (triple-loop) way.
    void FillMatrix( int** matrix1, int** matrix2, int length);// Fills matrix1 with random values in [0, 4] and copies the same values into matrix2.
    void PrintMatrix( int **MatrixA, int MatrixSize );// Prints the matrix contents to standard output.
    
    int main()
    {
    
        int MatrixSize = 0;
    
        int** MatrixA;
        int** MatrixB;
        int** MatrixC;
    
        clock_t startTime_For_Normal_Multipilication ;
        clock_t endTime_For_Normal_Multipilication ;
    
        clock_t startTime_For_Strassen ;
        clock_t endTime_For_Strassen ;
    
        time_t start,end;
    
        srand(time(0));
    
        cout<<setw(45)<<"In the name of GOD";
        cout<<endl<<setw(60)<<"Strassen Algorithm Implementation in C++ "
            <<endl<<endl<<setw(50)<<"By Seyyed Hossein Hasan Pour"
            <<endl<<setw(60)<<"Mazandaran University of Science and Technology"
            <<endl<<setw(40)<<"May 9 2010";
    
        cout<<"
    Please Enter your Matrix Size(must be in a power of two(eg:32,64,512,..): ";
        cin>>MatrixSize;
    
        int N = MatrixSize;//for readiblity.
    
    
        MatrixA = new int *[MatrixSize];
        MatrixB = new int *[MatrixSize];
        MatrixC = new int *[MatrixSize];
    
        for (int i = 0; i < MatrixSize; i++)
        {
            MatrixA[i] = new int [MatrixSize];
            MatrixB[i] = new int [MatrixSize];
            MatrixC[i] = new int [MatrixSize];
        }
    
        FillMatrix(MatrixA,MatrixB,MatrixSize);
    
      //*******************conventional multiplication test
            cout<<"Phase I started:  "<< (startTime_For_Normal_Multipilication = clock());
    
            MUL(MatrixA,MatrixB,MatrixC,MatrixSize);
    
            cout<<"
    Phase I ended: "<< (endTime_For_Normal_Multipilication = clock());
    
            cout<<"
    Matrix Result... 
    ";
            PrintMatrix(MatrixC,MatrixSize);
    
      //*******************Strassen multiplication test
            cout<<"
    Multiplication started: "<< (startTime_For_Strassen = clock());
    
            Strassen( N, MatrixA, MatrixB, MatrixC );
    
            cout<<"
    Multiplication: "<<(endTime_For_Strassen = clock());
    
    
        cout<<"
    Matrix Result... 
    ";
        PrintMatrix(MatrixC,MatrixSize);
    
        cout<<"Matrix size "<<MatrixSize;
        cout<<"
    Normal mode "<<(endTime_For_Normal_Multipilication - startTime_For_Normal_Multipilication)<<" Clocks.."<<(endTime_For_Normal_Multipilication - startTime_For_Normal_Multipilication)/CLOCKS_PER_SEC<<" Sec";
        cout<<"
    Strassen mode "<<(endTime_For_Strassen - startTime_For_Strassen)<<" Clocks.."<<(endTime_For_Strassen - startTime_For_Strassen)/CLOCKS_PER_SEC<<" Sec
    ";
        system("Pause");
        return 0;
    
    }
    /*
    To create a matrix in C++ whose size is not limited or fixed at compile time,
    one way is to build it from pointers.
    With a pointer-to-pointer we can make a multi-dimensional matrix of any size.
    This also lets us create a matrix with a VARIABLE size at run time, meaning we
    can change the size of the matrix while the program runs - shrink it or
    enlarge it, your choice.
    The idea is simple. First we declare a pointer-to-pointer variable: the first
    pointer points to another pointer, which in turn points to something else
    (we can make it point to an array).

    int **A;

    This only declares the variable; we now need to give it storage.
    Allocate a pointer-based array dynamically:

    A = new int *[desired_array_row];

    This gives us a one-dimensional array of pointers. Want a 2D array?
    No big deal, let's make one.
    We use for() to achieve this. Remember that the variable is a pointer to
    pointers, i.e. every location points somewhere else. We have just made a
    one-dimensional pointer-based array, and an array consists of individual
    elements, each of which can be used just like a stand-alone variable.
    So, if we could write

    A = new int *[any_size];

    for A, we can do the same for every individual element of that array: each
    element can be given an array of its own.
    We use for() to iterate over all elements of the array made above, and for
    each element we create a single array (one row):

    for ( int i = 0; i < desired_array_row; i++)
    A[i] = new int [desired_column_size];

    After this loop we have a 2D array that can be accessed like any ordinary
    array, using the conventional notation for both reading and writing (A[i][j]).
    Remember to free the space allocated for the 2D array at the end of the
    program. That is done like this:

    for ( int i = 0; i < your_array_row; i++)
    {
        delete [] A[i];
    }
    delete[] A;

    Using this method you can make an array with any number of dimensions; you
    just need to nest the for loops accordingly.
    */
    // Allocates a size x size matrix as an array of separately allocated rows.
    // The element values are left uninitialized.
    static int** AllocateSquareMatrix(int size)
    {
        int** matrix = new int *[size];
        for (int i = 0; i < size; i++)
        {
            matrix[i] = new int[size];
        }
        return matrix;
    }

    // Releases a matrix created by AllocateSquareMatrix(size): rows first, then the row table.
    static void FreeSquareMatrix(int** matrix, int size)
    {
        for (int i = 0; i < size; i++)
        {
            delete[] matrix[i];
        }
        delete[] matrix;
    }

    /*
    Multiplies two N x N matrices with Strassen's algorithm: MatrixC = MatrixA * MatrixB.

    N        order of the matrices
    MatrixA  left operand  (array of N row pointers, N ints per row)
    MatrixB  right operand (same layout)
    MatrixC  receives the product (same layout, must not alias MatrixA or MatrixB)

    Sizes up to the threshold are multiplied conventionally with MUL().
    An odd N is multiplied conventionally as well: it cannot be split into four
    equal quadrants, and splitting it anyway with N/2 would leave the last row and
    column of the result unwritten.

    Always returns 0.
    */
    int Strassen(int N, int **MatrixA, int **MatrixB, int **MatrixC)
    {
        //choosing the threshold is extremely important, try N<=2 to see the result
        if ( N <= 64 || N % 2 != 0 )
        {
            MUL(MatrixA,MatrixB,MatrixC,N);
            return 0;
        }

        const int HalfSize = N/2;

        // Quadrants of the inputs and of the result.
        int** A11 = AllocateSquareMatrix(HalfSize);
        int** A12 = AllocateSquareMatrix(HalfSize);
        int** A21 = AllocateSquareMatrix(HalfSize);
        int** A22 = AllocateSquareMatrix(HalfSize);

        int** B11 = AllocateSquareMatrix(HalfSize);
        int** B12 = AllocateSquareMatrix(HalfSize);
        int** B21 = AllocateSquareMatrix(HalfSize);
        int** B22 = AllocateSquareMatrix(HalfSize);

        int** C11 = AllocateSquareMatrix(HalfSize);
        int** C12 = AllocateSquareMatrix(HalfSize);
        int** C21 = AllocateSquareMatrix(HalfSize);
        int** C22 = AllocateSquareMatrix(HalfSize);

        // The seven Strassen products.
        int** M1 = AllocateSquareMatrix(HalfSize);
        int** M2 = AllocateSquareMatrix(HalfSize);
        int** M3 = AllocateSquareMatrix(HalfSize);
        int** M4 = AllocateSquareMatrix(HalfSize);
        int** M5 = AllocateSquareMatrix(HalfSize);
        int** M6 = AllocateSquareMatrix(HalfSize);
        int** M7 = AllocateSquareMatrix(HalfSize);

        // Scratch operands, reused for every intermediate sum/difference.
        int** AResult = AllocateSquareMatrix(HalfSize);
        int** BResult = AllocateSquareMatrix(HalfSize);

        //splitting input Matrixes, into 4 submatrices each.
        for (int i = 0; i < HalfSize; i++)
        {
            for (int j = 0; j < HalfSize; j++)
            {
                A11[i][j] = MatrixA[i][j];
                A12[i][j] = MatrixA[i][j + HalfSize];
                A21[i][j] = MatrixA[i + HalfSize][j];
                A22[i][j] = MatrixA[i + HalfSize][j + HalfSize];

                B11[i][j] = MatrixB[i][j];
                B12[i][j] = MatrixB[i][j + HalfSize];
                B21[i][j] = MatrixB[i + HalfSize][j];
                B22[i][j] = MatrixB[i + HalfSize][j + HalfSize];
            }
        }

        //here we calculate M1..M7 matrices; each product recurses into Strassen itself.
        //M1=(A11+A22)(B11+B22)
        ADD( A11,A22,AResult, HalfSize);
        ADD( B11,B22,BResult, HalfSize);
        Strassen( HalfSize, AResult, BResult, M1 );

        //M2=(A21+A22)B11
        ADD( A21,A22,AResult, HalfSize);
        Strassen(HalfSize, AResult, B11, M2);

        //M3=A11(B12-B22)
        SUB( B12,B22,BResult, HalfSize);
        Strassen(HalfSize, A11, BResult, M3);

        //M4=A22(B21-B11)
        SUB( B21, B11, BResult, HalfSize);
        Strassen(HalfSize, A22, BResult, M4);

        //M5=(A11+A12)B22
        ADD( A11, A12, AResult, HalfSize);
        Strassen(HalfSize, AResult, B22, M5);

        //M6=(A21-A11)(B11+B12)
        SUB( A21, A11, AResult, HalfSize);
        ADD( B11, B12, BResult, HalfSize);
        Strassen( HalfSize, AResult, BResult, M6);

        //M7=(A12-A22)(B21+B22)
        SUB(A12, A22, AResult, HalfSize);
        ADD(B21, B22, BResult, HalfSize);
        Strassen(HalfSize, AResult, BResult, M7);

        //C11 = M1 + M4 - M5 + M7;
        ADD( M1, M4, AResult, HalfSize);
        SUB( M7, M5, BResult, HalfSize);
        ADD( AResult, BResult, C11, HalfSize);

        //C12 = M3 + M5;
        ADD( M3, M5, C12, HalfSize);

        //C21 = M2 + M4;
        ADD( M2, M4, C21, HalfSize);

        //C22 = M1 + M3 - M2 + M6;
        ADD( M1, M3, AResult, HalfSize);
        SUB( M6, M2, BResult, HalfSize);
        ADD( AResult, BResult, C22, HalfSize);

        //at this point we have calculated the C11..C22 matrices; now put them
        //together into the single resulting matrix.
        for (int i = 0; i < HalfSize ; i++)
        {
            for (int j = 0 ; j < HalfSize ; j++)
            {
                MatrixC[i][j] = C11[i][j];
                MatrixC[i][j + HalfSize] = C12[i][j];
                MatrixC[i + HalfSize][j] = C21[i][j];
                MatrixC[i + HalfSize][j + HalfSize] = C22[i][j];
            }
        }

        // don't forget to free the space we allocated for the temporaries.
        FreeSquareMatrix(A11, HalfSize);
        FreeSquareMatrix(A12, HalfSize);
        FreeSquareMatrix(A21, HalfSize);
        FreeSquareMatrix(A22, HalfSize);

        FreeSquareMatrix(B11, HalfSize);
        FreeSquareMatrix(B12, HalfSize);
        FreeSquareMatrix(B21, HalfSize);
        FreeSquareMatrix(B22, HalfSize);

        FreeSquareMatrix(C11, HalfSize);
        FreeSquareMatrix(C12, HalfSize);
        FreeSquareMatrix(C21, HalfSize);
        FreeSquareMatrix(C22, HalfSize);

        FreeSquareMatrix(M1, HalfSize);
        FreeSquareMatrix(M2, HalfSize);
        FreeSquareMatrix(M3, HalfSize);
        FreeSquareMatrix(M4, HalfSize);
        FreeSquareMatrix(M5, HalfSize);
        FreeSquareMatrix(M6, HalfSize);
        FreeSquareMatrix(M7, HalfSize);

        FreeSquareMatrix(AResult, HalfSize);
        FreeSquareMatrix(BResult, HalfSize);

        return 0;
    }
    
    /*
    Element-wise sum: MatrixResult = MatrixA + MatrixB.

    All three matrices are arrays of row pointers; the first MatrixSize entries of
    the first MatrixSize rows are processed. Always returns 0.
    */
    int ADD(int** MatrixA, int** MatrixB, int** MatrixResult, int MatrixSize )
    {
        for (int r = 0; r < MatrixSize; ++r)
        {
            const int* lhs = MatrixA[r];
            const int* rhs = MatrixB[r];
            int* out = MatrixResult[r];
            for (int c = 0; c < MatrixSize; ++c)
            {
                out[c] = lhs[c] + rhs[c];
            }
        }
        return 0;
    }
    
    /*
    Element-wise difference: MatrixResult = MatrixA - MatrixB.

    All three matrices are arrays of row pointers; the first MatrixSize entries of
    the first MatrixSize rows are processed. Always returns 0.
    */
    int SUB(int** MatrixA, int** MatrixB, int** MatrixResult, int MatrixSize )
    {
        for (int r = 0; r < MatrixSize; ++r)
        {
            const int* lhs = MatrixA[r];
            const int* rhs = MatrixB[r];
            int* out = MatrixResult[r];
            for (int c = 0; c < MatrixSize; ++c)
            {
                out[c] = lhs[c] - rhs[c];
            }
        }
        return 0;
    }
    
    /*
    Conventional matrix product: MatrixResult = MatrixA * MatrixB.

    Each result element is first set to 0 and then accumulated in place over the
    shared dimension. All matrices are MatrixSize x MatrixSize arrays of row
    pointers. Always returns 0.
    */
    int MUL( int** MatrixA, int** MatrixB, int** MatrixResult, int MatrixSize )
    {
        for (int r = 0; r < MatrixSize; ++r)
        {
            int* out = MatrixResult[r];
            for (int c = 0; c < MatrixSize; ++c)
            {
                out[c] = 0;
                for (int t = 0; t < MatrixSize; ++t)
                {
                    out[c] += MatrixA[r][t] * MatrixB[t][c];
                }
            }
        }
        return 0;
    }
    
    /*
    Fills both matrices with the same pseudo-random data.

    For every cell, visited row by row, one value rand() % 5 is drawn and stored
    at the same position in MatrixA and in MatrixB.
    */
    void FillMatrix( int** MatrixA, int** MatrixB, int length)
    {
        for (int r = 0; r < length; ++r)
        {
            for (int c = 0; c < length; ++c)
            {
                const int value = rand() % 5;
                MatrixA[r][c] = value;
                MatrixB[r][c] = value;
                //matrix2[row][column] = rand() % 2;// removing this line gives about a 50% speed-up
            }
        }
    }
    /*
    Prints a MatrixSize x MatrixSize matrix to standard output.

    Output layout: one empty line, then one line per row in which every element is
    followed by a tab character, then one more empty line.
    */
    void PrintMatrix(int **MatrixA,int MatrixSize)
    {
        cout << endl;
        for (int r = 0; r < MatrixSize; ++r)
        {
            const int* cells = MatrixA[r];
            for (int c = 0; c < MatrixSize; ++c)
            {
                cout << cells[c] << "\t";
            }
            // End the line once the last column of the row has been written.
            cout << endl;
        }
        cout << endl;
    }
    

    完整代码2:数组表示:

    
    

    版权声明:本文为博主原创文章,未经博主允许不得转载。

  • 相关阅读:
    nyoj891找点(贪心)
    spark streaming方法
    spark submit打印gc信息
    spark dataframe方法解释
    structed streaming基础---跳过的坑
    scala学习---2
    增量式编码器定时器配置和速度计算的处理方法
    三次样条插补的实现
    增量式编码器计数的过零点处理问题
    串口发送带有使能引脚的注意事项
  • 原文地址:https://www.cnblogs.com/yangquanhui/p/4937467.html
Copyright © 2020-2023  润新知