矩阵乘法(Strassens 方法)C++

问题描述 投票:0回答:0

我正在尝试编写一个程序,该程序将使用 Strassen 方法使用以列主要顺序表示为一维数组的二维数组进行矩阵乘法。这是我现在所拥有的。

这是计算的主要方法。

vector<double> strassen_mult(int n, vector<double> A, vector<double> B) {
    vector<double> C(n*n);
    if (n == 1) {
        C.push_back(A[0] * B[0]);
    } else {
        vector<double> S1 = matrix_partitioner(n/2, 1, B) - matrix_partitioner(n/2, 3, B);
        vector<double> S2 = matrix_partitioner(n/2, 0, A) + matrix_partitioner(n/2, 1, A);
        vector<double> S3 = matrix_partitioner(n/2, 2, A) + matrix_partitioner(n/2, 3, A);
        vector<double> S4 = matrix_partitioner(n/2, 2, B) - matrix_partitioner(n/2, 0, B);
        vector<double> S5 = matrix_partitioner(n/2, 0, A) + matrix_partitioner(n/2, 3, A);
        vector<double> S6 = matrix_partitioner(n/2, 0, B) + matrix_partitioner(n/2, 3, B);
        vector<double> S7 = matrix_partitioner(n/2, 1, A) - matrix_partitioner(n/2, 3, A);
        vector<double> S8 = matrix_partitioner(n/2, 2, B) + matrix_partitioner(n/2, 3, B);
        vector<double> S9 = matrix_partitioner(n/2, 0, A) - matrix_partitioner(n/2, 2, A);
        vector<double> S10 = matrix_partitioner(n/2, 0, B) + matrix_partitioner(n/2, 1, B);
        vector<double> P1 = strassen_mult(n/2, matrix_partitioner(n/2, 0 ,A), S1);
        vector<double> P2 = strassen_mult(n/2, S2, matrix_partitioner(n/2, 3, B));
        vector<double> P3 = strassen_mult(n/2, S3, matrix_partitioner(n/2, 0, B));
        vector<double> P4 = strassen_mult(n/2, matrix_partitioner(n/2, 3 ,A), S4);
        vector<double> P5 = strassen_mult(n/2, S5, S6);
        vector<double> P6 = strassen_mult(n/2, S7, S8);
        vector<double> P7 = strassen_mult(n/2, S9, S10);
        C = equals(C, P5 + P4 - P2 + P6);
        C = equals(C, P1 + P2);
        C = equals(C, P3 + P4);
        C = equals(C, P5 + P1 - P3 - P7);
    }
    return C;
}

Matrix partitioner 从给定的矩阵中获取一个特定的象限。

vector<double> matrix_partitioner(int n, int section, vector<double> A) {
    vector<double> C(n*n);
    int start_i, start_j, tmp_j;
    if (section == 0) {
        start_i = 0;
        tmp_j = 0;
    }
    else if (section == 1) {
        start_i = 0;
        tmp_j = n;
    }
    else if (section == 2) {
        start_i = n;
        tmp_j = 0;
    }
    else if (section == 3) {
        start_i = n;
        tmp_j = n;
    }
    for (int i = 0; i < n; i++) {
        start_j = tmp_j;
        for (int j = 0; j < n; j++) {
            C[i+(j*n)] = A[start_i+(start_j*(n*2))];
            start_j++;
        }
        start_i++;
    }
    return C;
}

Equals将两个矩阵相加的结果放入C

vector<double> equals(vector<double> A, vector<double> B) {
    for (int i = 0; i < B.size(); i++) {
        A.push_back(B[i]);
    }
    return A;
}

我还重载了“+”和“-”运算符,以便更轻松地添加和减去矩阵。

这些是我得到的结果(我有迭代方法的结果可以比较,它们都使用相同的打印方法):

迭代 | 250 260 270 280 | | 618 644 670 696 | | 986 1028 1070 1112 | | 1354 1412 1470 1528 | 施特拉森 | 0 0 0 0 | | 0 0 0 0 | | 0 0 0 0 | | 0 0 0 0 |

显然我的结果不正确。有人可以帮我修复这段代码吗?

我试过移动东西,但没有找到好的方法

c++ recursion matrix matrix-multiplication strassen
© www.soinside.com 2019 - 2024. All rights reserved.