x86 内在:优化复杂浮点数的矩阵乘法

问题描述 投票:0回答:1

以下代码用于将复数浮点矩阵(单独的 Real、Imag)与浮点矩阵相乘。

我非常确定,由于加载、存储和乘法的延迟,可以通过重新排序代码来优化它。 您能告诉我是否存在如何优化代码来处理这种延迟的规则吗?

/***************************************************************************************/
void CVector::MatrixMultiply(float* pReA, float* pImA,
                            float* pTranB,
                            float* pOutRe, float* pOutIm,
                            uint32_t RowsA, uint32_t ColsA,
                            uint32_t RowsB, uint32_t ColsB)
{
    float *pSrcReA;
    float* pSrcImA;
    float* pSrcB;
    float* pDstRe = pOutRe;
    float* pDstIm = pOutIm;
    float* pRowReA, * pRowImA;

    __m256 ReSum, ImSum, VecReA, VecImA;
    __m256 *pAvec, *pBvec;
    __m256 VecA, VecB;
    __m128 Low, High, Sum128;
    __m128 Zero128 = _mm_set_ps1(0);

    uint32_t Offset;

    for (int i = 0; i < RowsA; i++)
    {
        Offset = ColsA * i;
        pSrcReA = pReA + Offset;
        pSrcImA = pImA + Offset;
        for (int j = 0; j < ColsB; j++)
        {
            ReSum = _mm256_set1_ps(0);
            ImSum = ReSum;
            pRowReA = pSrcReA;
            pRowImA = pSrcImA;
            pSrcB = pTranB + RowsB * j;

            for (int k = 0; k < (ColsA >> 3); k++)
            {
                VecReA = _mm256_load_ps((float*)pRowReA);
                VecImA = _mm256_load_ps((float*)pRowImA);
                VecB = _mm256_load_ps((float*)pSrcB);

                ReSum = _mm256_fmadd_ps (VecReA, VecB, ReSum);
                ImSum = _mm256_fmadd_ps (VecImA, VecB, ImSum);

                pRowReA += 8;
                pRowImA += 8;
            }

            Low = _mm256_extractf128_ps(ReSum, 0);
            High = _mm256_extractf128_ps(ReSum, 1);
            Sum128 = _mm_add_ps(Low, High);
            Sum128 = _mm_hadd_ps(Sum128, Zero128);
            Sum128 = _mm_hadd_ps(Sum128, Zero128);
            *pDstRe = _mm_cvtss_f32(Sum128);

            Low = _mm256_extractf128_ps(ImSum, 0);
            High = _mm256_extractf128_ps(ImSum, 1);
            Sum128 = _mm_add_ps(Low, High);
            Sum128 = _mm_hadd_ps(Sum128, Zero128);
            Sum128 = _mm_hadd_ps(Sum128, Zero128);
            *pDstIm = _mm_cvtss_f32(Sum128);

            pDstRe++;
            pDstIm++;
        }
    }
}
c x86 matrix-multiplication complex-numbers avx
1个回答
0
投票

代码最大的性能问题是(在大多数 CPU 上)

fmadd
有 4-5 个周期的延迟,但吞吐量倒数为 0.5(即,可以同时执行两个独立的 FMA)——来源: uops.info.

为了获得完整的吞吐量,您需要在内循环内执行 8 个(或在某些 CPU 上 10 个)独立的 FMA 操作。例如,有8个独立的

ReSum0..3
ImSum0..3
累加器,并通过8个
{VecReA, VecImA} * VecB0..3
乘积累加到它们。我不会写出来,因为我不完全理解你的代码,例如,为什么你不在
pSrcB
循环中增加
k
?您确定
ColsA==RowsB
并且它们是 8 的倍数吗?

© www.soinside.com 2019 - 2024. All rights reserved.