我正在尝试编写一个程序,该程序将使用 Strassen 方法使用以列主要顺序表示为一维数组的二维数组进行矩阵乘法。这是我现在所拥有的。
这是计算的主要方法。
vector<double> strassen_mult(int n, vector<double> A, vector<double> B) {
vector<double> C(n*n);
if (n == 1) {
C.push_back(A[0] * B[0]);
} else {
vector<double> S1 = matrix_partitioner(n/2, 1, B) - matrix_partitioner(n/2, 3, B);
vector<double> S2 = matrix_partitioner(n/2, 0, A) + matrix_partitioner(n/2, 1, A);
vector<double> S3 = matrix_partitioner(n/2, 2, A) + matrix_partitioner(n/2, 3, A);
vector<double> S4 = matrix_partitioner(n/2, 2, B) - matrix_partitioner(n/2, 0, B);
vector<double> S5 = matrix_partitioner(n/2, 0, A) + matrix_partitioner(n/2, 3, A);
vector<double> S6 = matrix_partitioner(n/2, 0, B) + matrix_partitioner(n/2, 3, B);
vector<double> S7 = matrix_partitioner(n/2, 1, A) - matrix_partitioner(n/2, 3, A);
vector<double> S8 = matrix_partitioner(n/2, 2, B) + matrix_partitioner(n/2, 3, B);
vector<double> S9 = matrix_partitioner(n/2, 0, A) - matrix_partitioner(n/2, 2, A);
vector<double> S10 = matrix_partitioner(n/2, 0, B) + matrix_partitioner(n/2, 1, B);
vector<double> P1 = strassen_mult(n/2, matrix_partitioner(n/2, 0 ,A), S1);
vector<double> P2 = strassen_mult(n/2, S2, matrix_partitioner(n/2, 3, B));
vector<double> P3 = strassen_mult(n/2, S3, matrix_partitioner(n/2, 0, B));
vector<double> P4 = strassen_mult(n/2, matrix_partitioner(n/2, 3 ,A), S4);
vector<double> P5 = strassen_mult(n/2, S5, S6);
vector<double> P6 = strassen_mult(n/2, S7, S8);
vector<double> P7 = strassen_mult(n/2, S9, S10);
C = equals(C, P5 + P4 - P2 + P6);
C = equals(C, P1 + P2);
C = equals(C, P3 + P4);
C = equals(C, P5 + P1 - P3 - P7);
}
return C;
}
Matrix partitioner 从给定的矩阵中获取一个特定的象限。
vector<double> matrix_partitioner(int n, int section, vector<double> A) {
vector<double> C(n*n);
int start_i, start_j, tmp_j;
if (section == 0) {
start_i = 0;
tmp_j = 0;
}
else if (section == 1) {
start_i = 0;
tmp_j = n;
}
else if (section == 2) {
start_i = n;
tmp_j = 0;
}
else if (section == 3) {
start_i = n;
tmp_j = n;
}
for (int i = 0; i < n; i++) {
start_j = tmp_j;
for (int j = 0; j < n; j++) {
C[i+(j*n)] = A[start_i+(start_j*(n*2))];
start_j++;
}
start_i++;
}
return C;
}
Equals将两个矩阵相加的结果放入C
vector<double> equals(vector<double> A, vector<double> B) {
for (int i = 0; i < B.size(); i++) {
A.push_back(B[i]);
}
return A;
}
我还重载了“+”和“-”运算符,以便更轻松地添加和减去矩阵。
这些是我得到的结果(我有迭代方法的结果可以比较,它们都使用相同的打印方法):
迭代 | 250 260 270 280 | | 618 644 670 696 | | 986 1028 1070 1112 | | 1354 1412 1470 1528 | 施特拉森 | 0 0 0 0 | | 0 0 0 0 | | 0 0 0 0 | | 0 0 0 0 |
显然我的结果不正确。有人可以帮我修复这段代码吗?
我试过移动东西,但没有找到好的方法