• strassen 算法实现矩阵相乘,当输入为偶数,奇数,任意数的实现


    1.使用随机数生成两个矩阵,并实现输出方法。

    public void makeMatrix(int[][] matrixA, int[][] matrixB, int length){  //生成矩阵
            Random random=new Random();
            for(int i=0;i<length;i++){
                for (int j=0;j<length;j++){
                    matrixA[i][j]=random.nextInt(5);
                    matrixB[i][j]=random.nextInt(5);
                }
            }
        }
        public  void printMatrix(int[][] matrixA,int length){ //输出
            for(int i=0;i<length;i++){
                for (int j=0;j<length;j++){
                    System.out.print(matrixA[i][j]+" ");
                    if((j+1)%length==0)
                        System.out.println();
                }
            }
        }

    2.使用Strassen算法需要涉及到矩阵的加减。所以先准备好方法。

    public void add(int[][] matrixA,int[][] matrixB,int[][] matrixC,int length){
            for(int i=0;i<length;i++) {
                for (int j = 0; j < length; j++) {
                    matrixC[i][j]= matrixA[i][j]+ matrixB[i][j];
                }
            }
        }
        public void jian(int[][] matrixA,int[][] matrixB,int[][] matrixC,int length){
            for(int i=0;i<length;i++) {
                for (int j = 0; j < length; j++) {
                    matrixC[i][j]= matrixA[i][j] - matrixB[i][j];
                }
            }
        }
    
        public void cheng(int[][] matrixA,int[][] matrixB,int[][] matrixC,int length){
            for(int i=0;i<length;i++) {
                for (int j = 0; j < length; j++) {
                    matrixC[i][j]=0;
                    for(int k=0;k<length;k++){
                        matrixC[i][j] = matrixC[i][j]+ matrixA[i][k] * matrixB[k][j];
                    }
    
                }
            }
        }

    3.当矩阵的阶数为 2的k次方

     //阶数为 2 的 K 次方的时候 Strassen算法
        public void strassen(int[][] matrixA,int[][] matrixB,int[][] matrixC,int N){
            int newsize=N/2;
            if(N==2){
                cheng(matrixA,matrixB,matrixC,N);
                return;
            }
            int[][] A11=new int[newsize][newsize];
            int[][] A12=new int[newsize][newsize];
            int[][] A21=new int[newsize][newsize];
            int[][] A22=new int[newsize][newsize];
    
            int[][] B11=new int[newsize][newsize];
            int[][] B12=new int[newsize][newsize];
            int[][] B21=new int[newsize][newsize];
            int[][] B22=new int[newsize][newsize];
    
            int[][] C11=new int[newsize][newsize];
            int[][] C12=new int[newsize][newsize];
            int[][] C21=new int[newsize][newsize];
            int[][] C22=new int[newsize][newsize];
    
            int[][] M1=new int[newsize][newsize];
            int[][] M2=new int[newsize][newsize];
            int[][] M3=new int[newsize][newsize];
            int[][] M4=new int[newsize][newsize];
            int[][] M5=new int[newsize][newsize];
            int[][] M6=new int[newsize][newsize];
            int[][] M7=new int[newsize][newsize];
    
            int[][] Aresult=new int[newsize][newsize];
            int[][] Bresult=new int[newsize][newsize];
    
            //分别给 A11 A12 A21 A22赋值
            for(int i=0;i<N/2;i++){
                for(int j=0;j<N/2;j++){
                    A11[i][j]=matrixA[i][j];
                    A12[i][j]=matrixA[i][j+N/2];
                    A21[i][j]=matrixA[i+N/2][j];
                    A22[i][j]=matrixA[i+N/2][j+N/2];
    
                    B11[i][j]=matrixB[i][j];
                    B12[i][j]=matrixB[i][j+N/2];
                    B21[i][j]=matrixB[i+N/2][j];
                    B22[i][j]=matrixB[i+N/2][j+N/2];
                }
            }
    
            //计算M1 到M7
            add(A11,A22,Aresult,newsize);
            add(B11,B22,Bresult,newsize);
            strassen(Aresult,Bresult,M1,newsize);
    
            //M2
            add(A21,A22,Aresult,newsize);
            strassen(Aresult,B11,M2,newsize);
    
            //M3
            jian(B12,B22,Bresult,newsize);
            strassen(A11,Bresult,M3,newsize);
    
            //M4
            jian(B21,B11,Bresult,newsize);
            strassen(A22,Bresult,M4,newsize);
    
            //M5
            add(A11,A12,Aresult,newsize);
            strassen(Aresult,B22,M5,newsize);
    
            //M6
            jian(A21,A11,Aresult,newsize);
            add(B11,B12,Bresult,newsize);
            strassen(Aresult,Bresult,M6,newsize);
    
            //M7
            jian(A12,A22,Aresult,newsize);
            add(B21,B22,Bresult,newsize);
            strassen(Aresult,Bresult,M7,newsize);
    
            //C11
            add(M1,M4,Aresult,newsize);
            jian(M5,M7,Bresult,newsize);
            jian(Aresult,Bresult,C11,newsize);
    
            //C12
            add(M3,M5,C12,newsize);
    
            //C21
            add(M2,M4,C21,newsize);
    
            //C22
            add(M1,M3,Aresult,newsize);
            jian(M2,M6,Bresult,newsize);
            jian(Aresult,Bresult,C22,newsize);
            //把C的值填充
            for(int i=0;i<N/2;i++){
                for(int j=0;j<N/2;j++){
                    matrixC[i][j]=C11[i][j];
                    matrixC[i][j+N/2]=C12[i][j];
                    matrixC[i+N/2][j]=C21[i][j];
                    matrixC[i+N/2][j+N/2]=C22[i][j];
                }
            }
        }

    4.当阶数为偶数的时候

    假设阶数为  n   ,所以 n=m*2^k    。此时,可以将偶数的矩阵拆分为 m*m个2*k的矩阵。大矩阵使用传统方法,小矩阵使用Strassen算法。

    //为阶数为偶数时的 Strassen算法
        public void evenNumber(int[][] matrixA,int[][] matrixB,int[][] matrixC,int N){
            int[] splits=getK(N);
            int m = splits[1];
            int k = splits[0];
            int jie=(int) Math.pow(2,k);
            //可以拆分为 m*m 个 2^^k 阶矩阵
            Object[][] TA = new Object[m][m];
            Object[][] TB = new Object[m][m];
            Object[][] TC = new Object[m][m];
            for(int hang=0;hang<m;hang++){
                for (int lie=0;lie<m;lie++){
                    int[][] matrixMA=new int[jie][jie];
                    int[][] matrixMB=new int[jie][jie];
                    //给矩阵MA ,MB 赋值
                    for(int i=0;i<jie;i++){
                        for(int j=0;j<jie;j++){
                            matrixMA[i][j]=matrixA[hang*jie+i][lie*jie+j];
                            matrixMB[i][j]=matrixB[hang*jie+i][lie*jie+j];
                        }
                    }
                    TA[hang][lie]=matrixMA;
                    TB[hang][lie]=matrixMB;
                }
            }
            //Object 数组中存放好了 m*m 的2^k 阶矩阵  所以 TA TB 看做两个矩阵做乘法
            for(int i=0;i<m;i++){
                for(int j=0;j<m;j++){
                    int[][] juzhenC = new int[jie][jie];
                    for(int p=0;p<m;p++){
                        int[][] juzhenA = (int[][])TA[i][p];
                        int[][] juzhenB = (int[][])TB[p][j];
                        int[][] chengres =new int[jie][jie];
                        int[][] addres =new int[jie][jie];
                        strassen(juzhenA,juzhenB,chengres,jie);
                        add(juzhenC,chengres,addres,jie);
                        juzhenC=addres;
                    }
                    TC[i][j]=juzhenC;
                }
            }
    
            //给矩阵C 结果矩阵进行赋值
            for(int hang=0;hang<m;hang++){
                for (int lie=0;lie<m;lie++){
                    int[][] matrixMC=(int[][])TC[hang][lie];
                    //给矩阵MA ,MB 赋值
                    for(int i=0;i<jie;i++){
                        for(int j=0;j<jie;j++){
                            matrixC[hang*jie+i][lie*jie+j]=matrixMC[i][j];
    
                        }
                    }
                }
            }
        }

    5.当阶数为奇数的时候,添加一行一列,放在第一行第一列,并吧 [0][0]位置设为1  ,这样不影响矩阵相乘结果。计算出结果再去掉第一行第一列即可。

     //为阶数为奇数时的 Strassen算法
        public void oddNumber(int[][] matrixA,int[][] matrixB,int[][] matrixC,int N){
            //扩容矩阵,第一行,第一列增加,A[0,0] 位置的值为0
            N=N+1;
            int[][] newA=new int[N][N];
            int[][] newB=new int[N][N];
            int[][] newC=new int[N][N];
            for(int i=0;i<N;i++){
                for(int j=0;j<N;j++){
                    if(i==0||j==0){
                        newA[i][j]=0;
                        newB[i][j]=0;
                        continue;
                    }
                    newA[i][j]=matrixA[i-1][j-1];
                    newB[i][j]=matrixB[i-1][j-1];
                }
            }
            newA[0][0]=1;
            newB[0][0]=1;
            evenNumber(newA,newB,newC,N);
            for(int i=0;i<N-1;i++) {
                for (int j = 0; j < N-1; j++) {
                    matrixC[i][j]=newC[i+1][j+1];
                }
            }
        }

    6.所有代码如下

    import java.util.Random;
    import java.util.Scanner;
    
    /**
     * 主要有四个方法:
     * 第一问,阶数为2^k 的时候矩阵相乘    strassen
     * 第二问  阶数为偶数次矩阵相乘      evenNumber
     * 第三问  阶数为技数次矩阵相乘      oddNumber
     *
     * 将这些方法封装到一个方法内  allType    这个方法不限制矩阵阶数,任意维度都可以相乘。
     * */
    public class Matrix {
        public void add(int[][] matrixA,int[][] matrixB,int[][] matrixC,int length){
            for(int i=0;i<length;i++) {
                for (int j = 0; j < length; j++) {
                    matrixC[i][j]= matrixA[i][j]+ matrixB[i][j];
                }
            }
        }
        public void jian(int[][] matrixA,int[][] matrixB,int[][] matrixC,int length){
            for(int i=0;i<length;i++) {
                for (int j = 0; j < length; j++) {
                    matrixC[i][j]= matrixA[i][j] - matrixB[i][j];
                }
            }
        }
    
        public void cheng(int[][] matrixA,int[][] matrixB,int[][] matrixC,int length){
            for(int i=0;i<length;i++) {
                for (int j = 0; j < length; j++) {
                    matrixC[i][j]=0;
                    for(int k=0;k<length;k++){
                        matrixC[i][j] = matrixC[i][j]+ matrixA[i][k] * matrixB[k][j];
                    }
    
                }
            }
        }
    
        public void makeMatrix(int[][] matrixA, int[][] matrixB, int length){
            Random random=new Random();
            for(int i=0;i<length;i++){
                for (int j=0;j<length;j++){
                    matrixA[i][j]=random.nextInt(5);
                    matrixB[i][j]=random.nextInt(5);
                }
            }
        }
        public  void printMatrix(int[][] matrixA,int length){
            for(int i=0;i<length;i++){
                for (int j=0;j<length;j++){
                    System.out.print(matrixA[i][j]+" ");
                    if((j+1)%length==0)
                        System.out.println();
                }
            }
        }
        /**
         * 计算阶数N   N=m*2^k
         * 返回值为数组   k,m
         * */
        public int[] getK(int N){
            int k=0;
            if(N%2==0){
                k++;
                N=N/2;
            }
            return new int[]{k,N};
        }
    
        /**
         * ============================================================================
         * 此分界线以上为辅助用的方法,下边三个分别为 三种情况对应的算法。
         * 最后一个是 参数为任意数时的方法。
         * */
    
    
        //阶数为 2 的 K 次方的时候 Strassen算法
        public void strassen(int[][] matrixA,int[][] matrixB,int[][] matrixC,int N){
            int newsize=N/2;
            if(N==2){
                cheng(matrixA,matrixB,matrixC,N);
                return;
            }
            int[][] A11=new int[newsize][newsize];
            int[][] A12=new int[newsize][newsize];
            int[][] A21=new int[newsize][newsize];
            int[][] A22=new int[newsize][newsize];
    
            int[][] B11=new int[newsize][newsize];
            int[][] B12=new int[newsize][newsize];
            int[][] B21=new int[newsize][newsize];
            int[][] B22=new int[newsize][newsize];
    
            int[][] C11=new int[newsize][newsize];
            int[][] C12=new int[newsize][newsize];
            int[][] C21=new int[newsize][newsize];
            int[][] C22=new int[newsize][newsize];
    
            int[][] M1=new int[newsize][newsize];
            int[][] M2=new int[newsize][newsize];
            int[][] M3=new int[newsize][newsize];
            int[][] M4=new int[newsize][newsize];
            int[][] M5=new int[newsize][newsize];
            int[][] M6=new int[newsize][newsize];
            int[][] M7=new int[newsize][newsize];
    
            int[][] Aresult=new int[newsize][newsize];
            int[][] Bresult=new int[newsize][newsize];
    
            //分别给 A11 A12 A21 A22赋值
            for(int i=0;i<N/2;i++){
                for(int j=0;j<N/2;j++){
                    A11[i][j]=matrixA[i][j];
                    A12[i][j]=matrixA[i][j+N/2];
                    A21[i][j]=matrixA[i+N/2][j];
                    A22[i][j]=matrixA[i+N/2][j+N/2];
    
                    B11[i][j]=matrixB[i][j];
                    B12[i][j]=matrixB[i][j+N/2];
                    B21[i][j]=matrixB[i+N/2][j];
                    B22[i][j]=matrixB[i+N/2][j+N/2];
                }
            }
    
            //计算M1 到M7
            add(A11,A22,Aresult,newsize);
            add(B11,B22,Bresult,newsize);
            strassen(Aresult,Bresult,M1,newsize);
    
            //M2
            add(A21,A22,Aresult,newsize);
            strassen(Aresult,B11,M2,newsize);
    
            //M3
            jian(B12,B22,Bresult,newsize);
            strassen(A11,Bresult,M3,newsize);
    
            //M4
            jian(B21,B11,Bresult,newsize);
            strassen(A22,Bresult,M4,newsize);
    
            //M5
            add(A11,A12,Aresult,newsize);
            strassen(Aresult,B22,M5,newsize);
    
            //M6
            jian(A21,A11,Aresult,newsize);
            add(B11,B12,Bresult,newsize);
            strassen(Aresult,Bresult,M6,newsize);
    
            //M7
            jian(A12,A22,Aresult,newsize);
            add(B21,B22,Bresult,newsize);
            strassen(Aresult,Bresult,M7,newsize);
    
            //C11
            add(M1,M4,Aresult,newsize);
            jian(M5,M7,Bresult,newsize);
            jian(Aresult,Bresult,C11,newsize);
    
            //C12
            add(M3,M5,C12,newsize);
    
            //C21
            add(M2,M4,C21,newsize);
    
            //C22
            add(M1,M3,Aresult,newsize);
            jian(M2,M6,Bresult,newsize);
            jian(Aresult,Bresult,C22,newsize);
            //把C的值填充
            for(int i=0;i<N/2;i++){
                for(int j=0;j<N/2;j++){
                    matrixC[i][j]=C11[i][j];
                    matrixC[i][j+N/2]=C12[i][j];
                    matrixC[i+N/2][j]=C21[i][j];
                    matrixC[i+N/2][j+N/2]=C22[i][j];
                }
            }
        }
    
        //为阶数为偶数时的 Strassen算法
        public void evenNumber(int[][] matrixA,int[][] matrixB,int[][] matrixC,int N){
            int[] splits=getK(N);
            int m = splits[1];
            int k = splits[0];
            int jie=(int) Math.pow(2,k);
            //可以拆分为 m*m 个 2^^k 阶矩阵
            Object[][] TA = new Object[m][m];
            Object[][] TB = new Object[m][m];
            Object[][] TC = new Object[m][m];
            for(int hang=0;hang<m;hang++){
                for (int lie=0;lie<m;lie++){
                    int[][] matrixMA=new int[jie][jie];
                    int[][] matrixMB=new int[jie][jie];
                    //给矩阵MA ,MB 赋值
                    for(int i=0;i<jie;i++){
                        for(int j=0;j<jie;j++){
                            matrixMA[i][j]=matrixA[hang*jie+i][lie*jie+j];
                            matrixMB[i][j]=matrixB[hang*jie+i][lie*jie+j];
                        }
                    }
                    TA[hang][lie]=matrixMA;
                    TB[hang][lie]=matrixMB;
                }
            }
            //Object 数组中存放好了 m*m 的2^k 阶矩阵  所以 TA TB 看做两个矩阵做乘法
            for(int i=0;i<m;i++){
                for(int j=0;j<m;j++){
                    int[][] juzhenC = new int[jie][jie];
                    for(int p=0;p<m;p++){
                        int[][] juzhenA = (int[][])TA[i][p];
                        int[][] juzhenB = (int[][])TB[p][j];
                        int[][] chengres =new int[jie][jie];
                        int[][] addres =new int[jie][jie];
                        strassen(juzhenA,juzhenB,chengres,jie);
                        add(juzhenC,chengres,addres,jie);
                        juzhenC=addres;
                    }
                    TC[i][j]=juzhenC;
                }
            }
    
            //给矩阵C 结果矩阵进行赋值
            for(int hang=0;hang<m;hang++){
                for (int lie=0;lie<m;lie++){
                    int[][] matrixMC=(int[][])TC[hang][lie];
                    //给矩阵MA ,MB 赋值
                    for(int i=0;i<jie;i++){
                        for(int j=0;j<jie;j++){
                            matrixC[hang*jie+i][lie*jie+j]=matrixMC[i][j];
    
                        }
                    }
                }
            }
        }
    
        //为阶数为奇数时的 Strassen算法
        public void oddNumber(int[][] matrixA,int[][] matrixB,int[][] matrixC,int N){
            //扩容矩阵,第一行,第一列增加,A[0,0] 位置的值为0
            N=N+1;
            int[][] newA=new int[N][N];
            int[][] newB=new int[N][N];
            int[][] newC=new int[N][N];
            for(int i=0;i<N;i++){
                for(int j=0;j<N;j++){
                    if(i==0||j==0){
                        newA[i][j]=0;
                        newB[i][j]=0;
                        continue;
                    }
                    newA[i][j]=matrixA[i-1][j-1];
                    newB[i][j]=matrixB[i-1][j-1];
                }
            }
            newA[0][0]=1;
            newB[0][0]=1;
            evenNumber(newA,newB,newC,N);
            for(int i=0;i<N-1;i++) {
                for (int j = 0; j < N-1; j++) {
                    matrixC[i][j]=newC[i+1][j+1];
                }
            }
        }
    
    
        //综合所有情况
        public void allType(int[][] matrixA,int[][] matrixB,int[][] matrixC,int N){
            if(N%2==1){
                oddNumber(matrixA,matrixB,matrixC,N);
            }else {
                int[] split=getK(N);
                if(split[1]==1){
                    strassen(matrixA,matrixB,matrixC,N);
                }else {
                    evenNumber(matrixA,matrixB,matrixC,N);
                }
            }
        }
    
    
        public static void main(String[] args) {
            System.out.println("请输入矩阵的阶数");
            Scanner input = new Scanner(System.in);
            int matrixSize=input.nextInt();
            int[][] matrixA=new int[matrixSize][matrixSize];
            int[][] matrixB=new int[matrixSize][matrixSize];
            int[][] matrixC=new int[matrixSize][matrixSize];
            int[][] matrixD=new int[matrixSize][matrixSize];
            Matrix t=new Matrix();
            //为矩阵 A B 赋值,填充矩阵
            t.makeMatrix(matrixA,matrixB,matrixSize);
    
            //输出 A B 矩阵
            System.out.println("矩阵A:");
            t.printMatrix(matrixA,matrixSize);
            System.out.println("矩阵B:");
            t.printMatrix(matrixB,matrixSize);
    
    
            //传统乘法,并输出
            t.cheng(matrixA,matrixB,matrixC,matrixSize);
            System.out.println("传统乘法:");
            t.printMatrix(matrixC,matrixSize);
    
            //strassen 乘法,并输出
            t.allType(matrixA,matrixB,matrixD,matrixSize);
            System.out.println("Strassen乘法:");
            t.printMatrix(matrixD,matrixSize);
    
        }
  • 相关阅读:
    [javase学习笔记]-8.7 静态代码块
    QT5.6 编译SQLServer驱动
    mnesia怎样改动表结构
    UVA 1541
    Topcoder SRM625 题解
    android自己定义渐变进度条
    显示vim当前颜色主题
    启动vim不加载.vimrc
    为ubuntu添加多媒体以及flash等等常用包
    linux c:关联变量的双for循环
  • 原文地址:https://www.cnblogs.com/wys-373/p/14111178.html
Copyright © 2020-2023  润新知