Strassen 矩阵乘法算法面临错误

问题描述 投票:0回答:1

我想高效地理解Strassen的矩阵乘法算法。

如果你能帮我解决这个问题。 这是我写的代码:

import java.util.*;

public class Main {

public static int[][] mulMatx(int[][] A, int[][] B) {
    int n = A.length;
    int[][] C = new int[1][1];
    
    if (n==1) {
        C[0][0] = A[0][0] * B[0][0];
    } else {
        
        // Sub Matrixes of Matrix A
        int[][] A11 = new int[n/2][n/2];
        int[][] A12 = new int[n/2][n/2];
        int[][] A21 = new int[n/2][n/2];
        int[][] A22 = new int[n/2][n/2];
    
        // Sub Matrixes of Matrix B
        int[][] B11 = new int[n/2][n/2];
        int[][] B12 = new int[n/2][n/2];
        int[][] B21 = new int[n/2][n/2];
        int[][] B22 = new int[n/2][n/2];
    
        // Split Matrixes of Matrix A
        splitMatx(A, A11, 0, 0);
        splitMatx(A, A12, 0, n/2);
        splitMatx(A, A21, n/2, 0);
        splitMatx(A, A22, n/2, n/2);
    
        // Split Matrixes of Matrix B
        splitMatx(B, B11, 0, 0);
        splitMatx(B, B12, 0, n/2);
        splitMatx(B, B21, n/2, 0);
        splitMatx(B, B22, n/2, n/2);
        
        // Calculate the seven products (M1 to M7) using Strassen recursive calls
        int[][] M1 = mulMatx(addMatx(A11, A22), addMatx(B11, B22));
        int[][] M2 = mulMatx(addMatx(A21, A22), B11);
        int[][] M3 = mulMatx(A11, subMatx(B12, B22));
        int[][] M4 = mulMatx(A22, subMatx(B21, B11));
        int[][] M5 = mulMatx(addMatx(A11, A12), B22);
        int[][] M6 = mulMatx(subMatx(A21, A11), addMatx(B11, B12));
        int[][] M7 = mulMatx(subMatx(A12, A22), addMatx(B21, B22));

        // Calculate the four quadrants of the result matrix C
        int[][] C11 = addMatx(subMatx(addMatx(M1, M4), M5), M7);
        int[][] C12 = addMatx(M3, M5);
        int[][] C21 = addMatx(M2, M4);
        int[][] C22 = addMatx(subMatx(addMatx(M1, M3), M2), M6);

        // Join the four quadrants to form the result matrix C
        joinMatx(C11, C, 0, 0);
        joinMatx(C12, C, 0, n / 2);
        joinMatx(C21, C, n / 2, 0);
        joinMatx(C22, C, n / 2, n / 2);
        
    }    
    return C;    
}

public static void splitMatx(int[][] mainMat, int[][] subMat, int s, int e){
    for(int i1=0,i2=s;i1<subMat.length;i1++,i2++){
        for(int j1=0,j2=e;j1<subMat.length;j1++,j2++){
          mainMat[i1][j1] = subMat[i2][j2];
        }
    }
}

public static void joinMatx(int[][] mainMat, int[][] subMat, int s, int e){
    for(int i1=0,i2=s;i1<subMat.length;i1++,i2++){
        for(int j1=0,j2=e;j1<subMat.length;j1++,j2++){
          subMat[i2][j2] = mainMat[i1][j1];
        }
    }
}

public static int[][] addMatx(int[][] A, int[][] B){
    int[][] C = new int[A.length][B.length];
    for(int i=0;i<A.length;i++) {
        for(int j=0;j<A.length;j++) {
            C[i][j] = A[i][j] + B[i][j];
        }
    }
    return C;
}

public static int[][] subMatx(int[][] A, int[][] B){
    int[][] C = new int[A.length][B.length];
    for(int i=0;i<A.length;i++) {
        for(int j=0;j<A.length;j++) {
            C[i][j] = A[i][j] - B[i][j];
        }
    }
    return C;
}

public static void printMatx(int[][] Mat) {
    for(int[] row : Mat){
        System.out.print(Arrays.toString(row));
    }
}

public static void main(String[] args) {
    int[][] matA = {
        {2, 4, 5, 9},
        {2, 12, 7, 5},
        {15, 8, 2, 5},
        {14, 4, 2, 5}
    };
    int[][] matB = {
        {7, 4, 2, 24},
        {15 , 7, 2, 6},
        {10 , 12, 2, 4},
        {9, 2, 6, 3}
    };
    printMatx(mulMatx(matA, matB));
}

}

它抛出错误:线程“main”中的异常 java.lang.ArrayIndexOutOfBoundsException:索引 2 超出长度 2 的范围

matrix-multiplication strassen
1个回答
0
投票

您的代码中存在一些小故障。 调试器是捕获此类错误的好工具。

我注释掉了原始行以显示更改。

import java.util.*;

public class Main {

    public static int[][] mulMatx(int[][] A, int[][] B) {
        int n = A.length;
        //  int[][] C = new int[1][1];
        int[][] C = new int[n][n];

        if (n==1) {
            C[0][0] = A[0][0] * B[0][0];
        } else {

            // Sub Matrixes of Matrix A
            int[][] A11 = new int[n/2][n/2];
            int[][] A12 = new int[n/2][n/2];
            int[][] A21 = new int[n/2][n/2];
            int[][] A22 = new int[n/2][n/2];

            // Sub Matrixes of Matrix B
            int[][] B11 = new int[n/2][n/2];
            int[][] B12 = new int[n/2][n/2];
            int[][] B21 = new int[n/2][n/2];
            int[][] B22 = new int[n/2][n/2];

            // Split Matrixes of Matrix A
            splitMatx(A, A11, 0, 0);
            splitMatx(A, A12, 0, n/2);
            splitMatx(A, A21, n/2, 0);
            splitMatx(A, A22, n/2, n/2);

            // Split Matrixes of Matrix B
            splitMatx(B, B11, 0, 0);
            splitMatx(B, B12, 0, n/2);
            splitMatx(B, B21, n/2, 0);
            splitMatx(B, B22, n/2, n/2);

            // Calculate the seven products (M1 to M7) using Strassen recursive calls
            int[][] M1 = mulMatx(addMatx(A11, A22), addMatx(B11, B22));
            int[][] M2 = mulMatx(addMatx(A21, A22), B11);
            int[][] M3 = mulMatx(A11, subMatx(B12, B22));
            int[][] M4 = mulMatx(A22, subMatx(B21, B11));
            int[][] M5 = mulMatx(addMatx(A11, A12), B22);
            int[][] M6 = mulMatx(subMatx(A21, A11), addMatx(B11, B12));
            int[][] M7 = mulMatx(subMatx(A12, A22), addMatx(B21, B22));

            // Calculate the four quadrants of the result matrix C
            int[][] C11 = addMatx(subMatx(addMatx(M1, M4), M5), M7);
            int[][] C12 = addMatx(M3, M5);
            int[][] C21 = addMatx(M2, M4);
            int[][] C22 = addMatx(subMatx(addMatx(M1, M3), M2), M6);

            // Join the four quadrants to form the result matrix C
            joinMatx(C11, C, 0, 0);
            joinMatx(C12, C, 0, n / 2);
            joinMatx(C21, C, n / 2, 0);
            joinMatx(C22, C, n / 2, n / 2);

        }
        return C;
    }

    public static void splitMatx(int[][] mainMat, int[][] subMat, int s, int e){
        for(int i1=0,i2=s;i1<subMat.length;i1++,i2++){
            for(int j1=0,j2=e;j1<subMat.length;j1++,j2++){
                //  mainMat[i1][j1] = subMat[i2][j2];
                subMat[i1][j1] = mainMat[i2][j2];
            }
        }
    }

//     public static void joinMatx(int[][] mainMat, int[][] subMat, int s, int e){
       public static void joinMatx(int[][] subMat, int[][] mainMat, int s, int e){
        for(int i1=0,i2=s;i1<subMat.length;i1++,i2++){
            for(int j1=0,j2=e;j1<subMat.length;j1++,j2++){
                //  subMat[i2][j2] = mainMat[i1][j1];
                mainMat[i2][j2] = subMat[i1][j1];
            }
        }
    }

    public static int[][] addMatx(int[][] A, int[][] B){
        int[][] C = new int[A.length][B.length];
        for(int i=0;i<A.length;i++) {
            for(int j=0;j<A.length;j++) {
                C[i][j] = A[i][j] + B[i][j];
            }
        }
        return C;
    }

    public static int[][] subMatx(int[][] A, int[][] B){
        int[][] C = new int[A.length][B.length];
        for(int i=0;i<A.length;i++) {
            for(int j=0;j<A.length;j++) {
                C[i][j] = A[i][j] - B[i][j];
            }
        }
        return C;
    }

    public static void printMatx(int[][] Mat) {
        for(int[] row : Mat){
            //  System.out.print(Arrays.toString(row));
            System.out.print(Arrays.toString(row) + "\n");
        }
    }

    public static void main(String[] args) {
        int[][] matA = {
                {2, 4, 5, 9},
                {2, 12, 7, 5},
                {15, 8, 2, 5},
                {14, 4, 2, 5}
        };
        int[][] matB = {
                {7, 4, 2, 24},
                {15 , 7, 2, 6},
                {10 , 12, 2, 4},
                {9, 2, 6, 3}
        };
        printMatx(mulMatx(matA, matB));
    }
}
© www.soinside.com 2019 - 2024. All rights reserved.