为什么我的矩阵乘法算法使用 2D C 样式数组比使用 1D std::array 更快?

问题描述 投票:0回答:1

我最近一直在编写和优化线性代数的东西,只是出于教育目的,并尝试编写一些代码来查看存储矩阵的不同方法之间的差异。 据我所知,使用 std::array 和 C 风格数组时,运行时应该几乎没有区别,并且我在代码中看不到 mult1 和 mult2 的时间复杂度之间有任何明显的差异:

#include <iostream>
#include <array>
#include <random>
#include <chrono>
using namespace std;
void setRandomValues(std::array<float, 10000>& arr1D) {
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<> dis(0, 1);

    for (auto& element : arr1D) {
        element = static_cast<float>(dis(gen));
    }
}
void setRandomValues(float(&array)[100][100]) {
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<> dis(0, 1);

    for (int i = 0; i < 100; ++i) {
        for (int j = 0; j < 100; ++j) {
            array[i][j] = static_cast<float>(dis(gen));
        }
    }
}
void mult1(array<float, 10000>& mat1, array<float, 10000>& mat2, array<float, 10000>& mat3)
{
    int rowsize = 100;
    for (size_t row = 0; row < 100; row++)
    {
        for (size_t col = 0; col < 100; col++)
        {
            float sum = 0;
            for (size_t k = 0; k < 100; k++)
            {
                int index1 = (row * rowsize) + k;
                int index2 = (k * rowsize) + col;
                sum += mat1[index1] * mat2[index2];
            }
            int index3 = (row * rowsize) + col;
            mat3[index3] = sum;
        }
    }
}
void mult2(float(&mat1)[100][100], float(&mat2)[100][100], float(&mat3)[100][100])
{
    for (size_t i = 0; i < 100; i++)
    {
        for (size_t x = 0; x < 100; x++)
        {
            float sum = 0;
            for (size_t k = 0; k < 100; k++)
            {
                sum += mat1[i][k] * mat2[k][x];
            }
            mat3[i][x] = sum;
        }
    }
}

int main()
{
    float mat1[100][100];
    float mat2[100][100];
    float mat3[100][100];
    std::array < float, 10000> mat4;
    std::array < float, 10000> mat5;
    std::array < float, 10000> mat6;
    setRandomValues(mat1);
    setRandomValues(mat2);
    setRandomValues(mat4);
    setRandomValues(mat5);

    auto start = std::chrono::high_resolution_clock::now();

    for (int i = 0; i < 100000; i++)
    {
        mult1(mat4, mat5, mat6);
    }

    auto stop = std::chrono::high_resolution_clock::now();
    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start);

    std::cout << "Using std::array in row major order " << duration.count() << std::endl;

    start = std::chrono::high_resolution_clock::now();

    for (int i = 0; i < 100000; i++)
    {
        mult2(mat1, mat2, mat3);
    }

    stop = std::chrono::high_resolution_clock::now();
    duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start);

    std::cout << "Using C style nested array " << duration.count() << std::endl;

    return 0;

}

那么为什么它又回来了:

Using std::array in row-major order 711461
Using C style nested array 53742

我错过了一些明显的事情吗?

c++ linear-algebra
1个回答
0
投票

我错过了一些明显的事情吗?

虽然我不认为这应该解释这么大的差异,但你肯定错过了一件明显的事情:使用

int
作为索引是不好的。

例如,MSVC 可能(取决于您编译代码的方式)生成如下内容:

$LL10@mult1:
        imul    eax, r8d, 100                     ; 00000064H
        add     eax, r9d
        movsxd  rdx, eax
        lea     eax, DWORD PTR [r10+r8]
        movsxd  rcx, eax
        inc     r8
        movss   xmm0, DWORD PTR [r11+rdx*4]
        mulss   xmm0, DWORD PTR [rbx+rcx*4]
        addss   xmm1, xmm0
        cmp     r8, 100                             ; 00000064H
        jb      SHORT $LL10@mult1

执行符号扩展的

movsxd
并不是这里最糟糕的事情。最糟糕的是代码根本就是这种形式。使用
size_t
索引(无论如何你几乎总是应该使用它),MSVC 删除了几乎所有索引算术 并且 展开了循环:

$LL10@mult1:
        movss   xmm1, DWORD PTR [rcx-400]
        mulss   xmm1, DWORD PTR [rax+4]
        movss   xmm0, DWORD PTR [rcx-800]
        mulss   xmm0, DWORD PTR [rax]
        addss   xmm0, xmm2
        movaps  xmm2, xmm1
        movss   xmm1, DWORD PTR [rcx+400]
        mulss   xmm1, DWORD PTR [rax+12]
        addss   xmm2, xmm0
        movss   xmm0, DWORD PTR [rax+8]
        mulss   xmm0, DWORD PTR [rcx]
        addss   xmm2, xmm0
        movss   xmm0, DWORD PTR [rcx+800]
        mulss   xmm0, DWORD PTR [rax+16]
        addss   xmm2, xmm1
        movss   xmm1, DWORD PTR [rcx+1200]
        mulss   xmm1, DWORD PTR [rax+20]
        addss   xmm2, xmm0
        movss   xmm0, DWORD PTR [rcx+1600]
        mulss   xmm0, DWORD PTR [rax+24]
        addss   xmm2, xmm1
        movss   xmm1, DWORD PTR [rcx+2000]
        mulss   xmm1, DWORD PTR [rax+28]
        addss   xmm2, xmm0
        movss   xmm0, DWORD PTR [rax+32]
        mulss   xmm0, DWORD PTR [rcx+2400]
        addss   xmm2, xmm1
        movss   xmm1, DWORD PTR [rax+36]
        mulss   xmm1, DWORD PTR [rcx+2800]
        add     rax, 40                             ; 00000028H
        add     rcx, 4000               ; 00000fa0H
        addss   xmm2, xmm0
        addss   xmm2, xmm1
        sub     r8, 1
        jne     $LL10@mult1

遗憾的是没有自动矢量化。这仍然是标量代码,只是展开并且开销较少。

© www.soinside.com 2019 - 2024. All rights reserved.