我想高效地理解 Strassen 矩阵乘法算法，并尝试用 Java 实现了它。
如果你能帮我解决下面的问题，我将不胜感激。这是我写的代码：
import java.util.*;
// The question's original attempt at Strassen matrix multiplication.
// As written it throws ArrayIndexOutOfBoundsException at runtime; the NOTE(bug)
// comments below mark the causes.
public class Main {
// Recursively multiplies A * B with Strassen's seven sub-products.
// Assumes both inputs are square and of equal size n, with n a power of two so that
// repeated n / 2 halving reaches the n == 1 base case -- TODO confirm for callers.
public static int[][] mulMatx(int[][] A, int[][] B) {
int n = A.length;
// NOTE(bug): the result is always allocated 1 x 1; it must be n x n, otherwise
// joinMatx cannot place the quadrants at offset n / 2.
int[][] C = new int[1][1];
if (n==1) {
C[0][0] = A[0][0] * B[0][0];
} else {
// Quadrants of matrix A
int[][] A11 = new int[n/2][n/2];
int[][] A12 = new int[n/2][n/2];
int[][] A21 = new int[n/2][n/2];
int[][] A22 = new int[n/2][n/2];
// Quadrants of matrix B
int[][] B11 = new int[n/2][n/2];
int[][] B12 = new int[n/2][n/2];
int[][] B21 = new int[n/2][n/2];
int[][] B22 = new int[n/2][n/2];
// Fill the quadrants of A (intended; see the splitMatx bug note)
splitMatx(A, A11, 0, 0);
splitMatx(A, A12, 0, n/2);
splitMatx(A, A21, n/2, 0);
splitMatx(A, A22, n/2, n/2);
// Fill the quadrants of B
splitMatx(B, B11, 0, 0);
splitMatx(B, B12, 0, n/2);
splitMatx(B, B21, n/2, 0);
splitMatx(B, B22, n/2, n/2);
// Calculate the seven products (M1 to M7) using Strassen recursive calls
int[][] M1 = mulMatx(addMatx(A11, A22), addMatx(B11, B22));
int[][] M2 = mulMatx(addMatx(A21, A22), B11);
int[][] M3 = mulMatx(A11, subMatx(B12, B22));
int[][] M4 = mulMatx(A22, subMatx(B21, B11));
int[][] M5 = mulMatx(addMatx(A11, A12), B22);
int[][] M6 = mulMatx(subMatx(A21, A11), addMatx(B11, B12));
int[][] M7 = mulMatx(subMatx(A12, A22), addMatx(B21, B22));
// Calculate the four quadrants of the result matrix C
int[][] C11 = addMatx(subMatx(addMatx(M1, M4), M5), M7);
int[][] C12 = addMatx(M3, M5);
int[][] C21 = addMatx(M2, M4);
int[][] C22 = addMatx(subMatx(addMatx(M1, M3), M2), M6);
// Join the four quadrants to form the result matrix C
joinMatx(C11, C, 0, 0);
joinMatx(C12, C, 0, n / 2);
joinMatx(C21, C, n / 2, 0);
joinMatx(C22, C, n / 2, n / 2);
}
return C;
}
// Intended to copy the subMat-sized block of mainMat whose top-left corner is (s, e) into subMat.
// NOTE(bug): the assignment runs the other way. It overwrites mainMat with subMat's (zero)
// contents and reads subMat at offset (s, e). For a nonzero offset, subMat[i2][j2] is read
// past its end -- with n = 4, the call for A12 reads subMat[0][2] and throws the reported
// "Index 2 out of bounds for length 2".
public static void splitMatx(int[][] mainMat, int[][] subMat, int s, int e){
for(int i1=0,i2=s;i1<subMat.length;i1++,i2++){
for(int j1=0,j2=e;j1<subMat.length;j1++,j2++){
mainMat[i1][j1] = subMat[i2][j2];
}
}
}
// Intended to copy a quadrant (first argument) into the result (second argument) at offset (s, e).
// NOTE(bug): the loop bounds use subMat.length. mulMatx calls this as joinMatx(C11, C, ...),
// so subMat is the full result matrix C, not the quadrant. The bounds should come from the quadrant.
public static void joinMatx(int[][] mainMat, int[][] subMat, int s, int e){
for(int i1=0,i2=s;i1<subMat.length;i1++,i2++){
for(int j1=0,j2=e;j1<subMat.length;j1++,j2++){
subMat[i2][j2] = mainMat[i1][j1];
}
}
}
// Element-wise sum A + B. The result is sized [A.length][B.length] and columns are iterated
// up to A.length, so this is only correct for square matrices of equal size.
public static int[][] addMatx(int[][] A, int[][] B){
int[][] C = new int[A.length][B.length];
for(int i=0;i<A.length;i++) {
for(int j=0;j<A.length;j++) {
C[i][j] = A[i][j] + B[i][j];
}
}
return C;
}
// Element-wise difference A - B; same square-matrix assumption as addMatx.
public static int[][] subMatx(int[][] A, int[][] B){
int[][] C = new int[A.length][B.length];
for(int i=0;i<A.length;i++) {
for(int j=0;j<A.length;j++) {
C[i][j] = A[i][j] - B[i][j];
}
}
return C;
}
// Prints each row with Arrays.toString. It uses print (not println), so all rows
// appear on a single line.
public static void printMatx(int[][] Mat) {
for(int[] row : Mat){
System.out.print(Arrays.toString(row));
}
}
// Multiplies two fixed 4 x 4 matrices and prints the product.
public static void main(String[] args) {
int[][] matA = {
{2, 4, 5, 9},
{2, 12, 7, 5},
{15, 8, 2, 5},
{14, 4, 2, 5}
};
int[][] matB = {
{7, 4, 2, 24},
{15 , 7, 2, 6},
{10 , 12, 2, 4},
{9, 2, 6, 3}
};
printMatx(mulMatx(matA, matB));
}
}
运行时抛出错误：`Exception in thread "main" java.lang.ArrayIndexOutOfBoundsException: Index 2 out of bounds for length 2`（即：线程 "main" 中出现异常，索引 2 超出长度为 2 的数组范围）。
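您的代码中有几处小错误：结果矩阵 `C` 被分配为 1×1（应为 n×n）；`splitMatx` 的复制方向反了（应从大矩阵复制到子矩阵）；`joinMatx` 的循环边界取自结果矩阵而不是子矩阵；另外 `printMatx` 输出时没有换行。调试器是发现这类错误的好工具。
我把原来的代码行注释掉了，方便对比修改之处。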
import java.util.*;
public class Main {
public static int[][] mulMatx(int[][] A, int[][] B) {
int n = A.length;
// int[][] C = new int[1][1];
int[][] C = new int[n][n];
if (n==1) {
C[0][0] = A[0][0] * B[0][0];
} else {
// Sub Matrixes of Matrix A
int[][] A11 = new int[n/2][n/2];
int[][] A12 = new int[n/2][n/2];
int[][] A21 = new int[n/2][n/2];
int[][] A22 = new int[n/2][n/2];
// Sub Matrixes of Matrix B
int[][] B11 = new int[n/2][n/2];
int[][] B12 = new int[n/2][n/2];
int[][] B21 = new int[n/2][n/2];
int[][] B22 = new int[n/2][n/2];
// Split Matrixes of Matrix A
splitMatx(A, A11, 0, 0);
splitMatx(A, A12, 0, n/2);
splitMatx(A, A21, n/2, 0);
splitMatx(A, A22, n/2, n/2);
// Split Matrixes of Matrix B
splitMatx(B, B11, 0, 0);
splitMatx(B, B12, 0, n/2);
splitMatx(B, B21, n/2, 0);
splitMatx(B, B22, n/2, n/2);
// Calculate the seven products (M1 to M7) using Strassen recursive calls
int[][] M1 = mulMatx(addMatx(A11, A22), addMatx(B11, B22));
int[][] M2 = mulMatx(addMatx(A21, A22), B11);
int[][] M3 = mulMatx(A11, subMatx(B12, B22));
int[][] M4 = mulMatx(A22, subMatx(B21, B11));
int[][] M5 = mulMatx(addMatx(A11, A12), B22);
int[][] M6 = mulMatx(subMatx(A21, A11), addMatx(B11, B12));
int[][] M7 = mulMatx(subMatx(A12, A22), addMatx(B21, B22));
// Calculate the four quadrants of the result matrix C
int[][] C11 = addMatx(subMatx(addMatx(M1, M4), M5), M7);
int[][] C12 = addMatx(M3, M5);
int[][] C21 = addMatx(M2, M4);
int[][] C22 = addMatx(subMatx(addMatx(M1, M3), M2), M6);
// Join the four quadrants to form the result matrix C
joinMatx(C11, C, 0, 0);
joinMatx(C12, C, 0, n / 2);
joinMatx(C21, C, n / 2, 0);
joinMatx(C22, C, n / 2, n / 2);
}
return C;
}
public static void splitMatx(int[][] mainMat, int[][] subMat, int s, int e){
for(int i1=0,i2=s;i1<subMat.length;i1++,i2++){
for(int j1=0,j2=e;j1<subMat.length;j1++,j2++){
// mainMat[i1][j1] = subMat[i2][j2];
subMat[i1][j1] = mainMat[i2][j2];
}
}
}
// public static void joinMatx(int[][] mainMat, int[][] subMat, int s, int e){
public static void joinMatx(int[][] subMat, int[][] mainMat, int s, int e){
for(int i1=0,i2=s;i1<subMat.length;i1++,i2++){
for(int j1=0,j2=e;j1<subMat.length;j1++,j2++){
// subMat[i2][j2] = mainMat[i1][j1];
mainMat[i2][j2] = subMat[i1][j1];
}
}
}
public static int[][] addMatx(int[][] A, int[][] B){
int[][] C = new int[A.length][B.length];
for(int i=0;i<A.length;i++) {
for(int j=0;j<A.length;j++) {
C[i][j] = A[i][j] + B[i][j];
}
}
return C;
}
public static int[][] subMatx(int[][] A, int[][] B){
int[][] C = new int[A.length][B.length];
for(int i=0;i<A.length;i++) {
for(int j=0;j<A.length;j++) {
C[i][j] = A[i][j] - B[i][j];
}
}
return C;
}
public static void printMatx(int[][] Mat) {
for(int[] row : Mat){
// System.out.print(Arrays.toString(row));
System.out.print(Arrays.toString(row) + "\n");
}
}
public static void main(String[] args) {
int[][] matA = {
{2, 4, 5, 9},
{2, 12, 7, 5},
{15, 8, 2, 5},
{14, 4, 2, 5}
};
int[][] matB = {
{7, 4, 2, 24},
{15 , 7, 2, 6},
{10 , 12, 2, 4},
{9, 2, 6, 3}
};
printMatx(mulMatx(matA, matB));
}
}