为什么 clang 无法向量化这个遍历 std::span 并将结果写入 std::array 的循环?

问题描述 投票:0回答:1

为什么 clang 17.0.1 不会矢量化以下函数中的循环:

// Excerpt of the member function NlmsFilter::adapt (F, N, w, mu, eps and
// dot_prod come from the enclosing class template / full listing).
// Takes a fixed-extent window of N + 1 samples and updates the weights `w`
// in place from the first N samples and the last one.
void adapt(std::span<const F, N + 1> signal)
{
    // The last sample of the window is the value being compared against.
    F true_val = signal.back();
    // Dot product of the first N samples with the current weights.
    F y = dot_prod<F, N>(&signal[0], w.begin());
    // Difference between the last sample and that dot product.
    F e = true_val - y;
    // Dot product of the first N samples with themselves.
    F dot = dot_prod<F, N>(&signal[0], &signal[0]);
    // Step size: mu divided by (eps + dot).
    F nu = mu / (eps + dot);
    // The loop the question is about: reads signal[i], reads and writes w[i].
    for (size_t i = 0; i < N; i++)
    {
        w[i] += nu * e * signal[i];
    }
}

不存在循环携带依赖(loop-carried dependency)或浮点结合律问题,GCC 13.2 对其进行向量化没有任何问题。

这是 Compiler Explorer(godbolt)上完整代码的链接。

背景是我正在尝试优化代码,使其使用向量化的点积。通常,由于浮点运算不满足结合律,

std::inner_product
会被编译成标量实现,除非您使用
-ffast-math
。但是,我只希望
-ffast-math
应用于单个函数,并且我一直在寻找一种可移植的方法来为 clang 和 GCC 执行此操作。在查看输出时,我注意到 clang 不会矢量化另一个循环。

这是完整的 C++:

#include <cstddef>
#include <span>
#include <numeric>
#include <cassert>

// Relaxed floating-point semantics are requested for this one function only,
// in two ways: the float_control pragma pair around it and the optimize
// attribute on it.
// NOTE(review): presumably the pragma is the one clang honors and the
// attribute is the one GCC honors -- confirm against each compiler's docs.
#pragma float_control(precise, off, push)
/// Dot product of the first N elements of `a` and `b`.
/// @tparam F  element type
/// @tparam N  number of elements read from each pointer
/// @param a   pointer to at least N readable elements
/// @param b   pointer to at least N readable elements (may equal `a`)
/// @return    sum over i in [0, N) of a[i] * b[i]
template <typename F, size_t N> 
__attribute__((optimize("-ffast-math"))) 
constexpr inline F dot_prod(const F *a, const F *b)
{
    F acc = 0.0f;
    // Plain accumulation loop; vectorizing it means reordering the additions.
    for(size_t i = 0; i < N; i++)
        acc += a[i] * b[i];
    return acc;
}
// Restore the floating-point settings that were pushed above.
#pragma float_control(pop)

/// Filter holding N weights of type F, with a prediction step and a
/// weight-update step.
/// NOTE(review): the name suggests a normalized-LMS adaptive filter -- confirm.
/// NOTE(review): std::array is used but <array> is not among the includes
/// listed above this class; it is presumably pulled in transitively.
template <typename F, size_t N> class NlmsFilter
{
    // Numerator of the step size used in adapt().
    static constexpr F mu = 0.5;
    // Added to the denominator of the step size in adapt().
    static constexpr F eps = 1.0;
    // Filter weights, value-initialized.
    std::array<F, N> w = {};

  public:
    /// Returns a value computed from `signal` (dynamic extent, at most N + 1
    /// samples; larger sizes trip the assert).
    ///  - exactly N + 1 samples: update the weights via adapt(), then return
    ///    the dot product of samples [1, N] with the updated weights;
    ///  - no samples: return 0;
    ///  - otherwise: return the last sample unchanged.
    F predict(std::span<const F> signal)
    {
        assert(signal.size() <= N + 1);
        if (signal.size() == N + 1)
        {
            // Convert to a fixed-extent span so adapt() sees a compile-time size.
            auto signal_static = signal.template subspan<0, N + 1>();
            adapt(signal_static);
            // NOTE(review): w.begin() is passed where dot_prod takes const F*;
            // this relies on std::array's iterator being a raw pointer.
            return dot_prod<F, N>(&signal_static[1], w.begin());
        }
        else if (signal.size() == 0)
        {
            return 0.0f;
        }
        else
        {
            return signal.back();
        }
    }
    /// Updates the weights `w` in place from a window of N + 1 samples:
    /// the first N samples are the input, the last one is compared against
    /// their weighted sum.
    void adapt(std::span<const F, N + 1> signal)
    {
        // The last sample of the window is the value being compared against.
        F true_val = signal.back();
        // Dot product of the first N samples with the current weights.
        F y = dot_prod<F, N>(&signal[0], w.begin());
        // Difference between the last sample and that dot product.
        F e = true_val - y;
        // Dot product of the first N samples with themselves.
        F dot = dot_prod<F, N>(&signal[0], &signal[0]);
        // Step size: mu divided by (eps + dot).
        F nu = mu / (eps + dot);
        // Reads signal[i], reads and writes w[i]; the compiler cannot see from
        // the types alone that the two ranges do not overlap.
        for (size_t i = 0; i < N; i++)
        {
            w[i] += nu * e * signal[i];
        }
    }
};

template class NlmsFilter<float, 32>;

这是 clang 的汇编输出:

# Compiler output for predict(). rdx is compared against 34, 0 and 33, i.e. it
# carries signal.size(); rsi is used as the base address of the samples.
NlmsFilter<float, 32ul>::predict(std::span<float const, 18446744073709551615ul>): # @NlmsFilter<float, 32ul>::predict(std::span<float const, 18446744073709551615ul>)
        push    r14
        push    rbx
        push    rax
        # size >= 34 (> N + 1): jump to the assertion-failure path.
        cmp     rdx, 34
        jae     .LBB0_6
        # size == 0: return 0.
        test    rdx, rdx
        je      .LBB0_2
        mov     rbx, rsi
        # size != 33: return the last sample.
        cmp     rdx, 33
        jne     .LBB0_4
        # size == 33: call adapt(), then tail-call dot_prod on (data + 4 bytes, this).
        mov     r14, rdi
        mov     rsi, rbx
        call    NlmsFilter<float, 32ul>::adapt(std::span<float const, 33ul>)
        add     rbx, 4
        mov     rdi, rbx
        mov     rsi, r14
        add     rsp, 8
        pop     rbx
        pop     r14
        jmp     float dot_prod<float, 32ul>(float const*, float const*)   # TAILCALL
.LBB0_2:
        vxorps  xmm0, xmm0, xmm0
        add     rsp, 8
        pop     rbx
        pop     r14
        ret
.LBB0_4:
        # Load the float at index size - 1 (signal.back()).
        vmovss  xmm0, dword ptr [rbx + 4*rdx - 4] # xmm0 = mem[0],zero,zero,zero
        add     rsp, 8
        pop     rbx
        pop     r14
        ret
.LBB0_6:
        # Assertion failure: pass expression text, file name, line 27 and function name.
        mov     edi, offset .L.str
        mov     esi, offset .L.str.1
        mov     ecx, offset .L__PRETTY_FUNCTION__.NlmsFilter<float, 32ul>::predict(std::span<float const, 18446744073709551615ul>)
        mov     edx, 27
        call    __assert_fail
# Constant pool for adapt(): 1.0f and 0.5f, matching the values of eps and mu.
.LCPI1_0:
        .long   0x3f800000                      # float 1
.LCPI1_1:
        .long   0x3f000000                      # float 0.5
# Compiler output for adapt(). rdi is kept in rbx and rsi in r14; both are
# used as base addresses for 32 consecutive floats below (rbx is read and
# written, r14 is only read).
NlmsFilter<float, 32ul>::adapt(std::span<float const, 33ul>): # @NlmsFilter<float, 32ul>::adapt(std::span<float const, 33ul>)
        push    r14
        push    rbx
        push    rax
        mov     r14, rsi
        mov     rbx, rdi
        # Load the float at byte offset 128 (index 32, the last sample) and spill it.
        vmovss  xmm0, dword ptr [rsi + 128]     # xmm0 = mem[0],zero,zero,zero
        vmovss  dword ptr [rsp + 4], xmm0       # 4-byte Spill
        # First dot_prod call: (samples, weights).
        mov     rdi, rsi
        mov     rsi, rbx
        call    float dot_prod<float, 32ul>(float const*, float const*)
        # e = last sample - dot product; spilled across the next call.
        vmovss  xmm1, dword ptr [rsp + 4]       # 4-byte Reload
        vsubss  xmm0, xmm1, xmm0
        vmovss  dword ptr [rsp + 4], xmm0       # 4-byte Spill
        # Second dot_prod call: (samples, samples).
        mov     rdi, r14
        mov     rsi, r14
        call    float dot_prod<float, 32ul>(float const*, float const*)
        # xmm0 = 0.5 / (1.0 + dot), then multiplied by e.
        vaddss  xmm0, xmm0, dword ptr [rip + .LCPI1_0]
        vmovss  xmm1, dword ptr [rip + .LCPI1_1] # xmm1 = mem[0],zero,zero,zero
        vdivss  xmm0, xmm1, xmm0
        vmulss  xmm0, xmm0, dword ptr [rsp + 4] # 4-byte Folded Reload
        # The update loop: fully unrolled into 32 multiply / add / store triples.
        # Every instruction is the scalar single-float form (vmulss / vaddss /
        # vmovss), i.e. the loop was NOT vectorized.
        vmulss  xmm1, xmm0, dword ptr [r14]
        vaddss  xmm1, xmm1, dword ptr [rbx]
        vmovss  dword ptr [rbx], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 4]
        vaddss  xmm1, xmm1, dword ptr [rbx + 4]
        vmovss  dword ptr [rbx + 4], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 8]
        vaddss  xmm1, xmm1, dword ptr [rbx + 8]
        vmovss  dword ptr [rbx + 8], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 12]
        vaddss  xmm1, xmm1, dword ptr [rbx + 12]
        vmovss  dword ptr [rbx + 12], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 16]
        vaddss  xmm1, xmm1, dword ptr [rbx + 16]
        vmovss  dword ptr [rbx + 16], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 20]
        vaddss  xmm1, xmm1, dword ptr [rbx + 20]
        vmovss  dword ptr [rbx + 20], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 24]
        vaddss  xmm1, xmm1, dword ptr [rbx + 24]
        vmovss  dword ptr [rbx + 24], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 28]
        vaddss  xmm1, xmm1, dword ptr [rbx + 28]
        vmovss  dword ptr [rbx + 28], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 32]
        vaddss  xmm1, xmm1, dword ptr [rbx + 32]
        vmovss  dword ptr [rbx + 32], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 36]
        vaddss  xmm1, xmm1, dword ptr [rbx + 36]
        vmovss  dword ptr [rbx + 36], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 40]
        vaddss  xmm1, xmm1, dword ptr [rbx + 40]
        vmovss  dword ptr [rbx + 40], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 44]
        vaddss  xmm1, xmm1, dword ptr [rbx + 44]
        vmovss  dword ptr [rbx + 44], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 48]
        vaddss  xmm1, xmm1, dword ptr [rbx + 48]
        vmovss  dword ptr [rbx + 48], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 52]
        vaddss  xmm1, xmm1, dword ptr [rbx + 52]
        vmovss  dword ptr [rbx + 52], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 56]
        vaddss  xmm1, xmm1, dword ptr [rbx + 56]
        vmovss  dword ptr [rbx + 56], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 60]
        vaddss  xmm1, xmm1, dword ptr [rbx + 60]
        vmovss  dword ptr [rbx + 60], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 64]
        vaddss  xmm1, xmm1, dword ptr [rbx + 64]
        vmovss  dword ptr [rbx + 64], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 68]
        vaddss  xmm1, xmm1, dword ptr [rbx + 68]
        vmovss  dword ptr [rbx + 68], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 72]
        vaddss  xmm1, xmm1, dword ptr [rbx + 72]
        vmovss  dword ptr [rbx + 72], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 76]
        vaddss  xmm1, xmm1, dword ptr [rbx + 76]
        vmovss  dword ptr [rbx + 76], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 80]
        vaddss  xmm1, xmm1, dword ptr [rbx + 80]
        vmovss  dword ptr [rbx + 80], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 84]
        vaddss  xmm1, xmm1, dword ptr [rbx + 84]
        vmovss  dword ptr [rbx + 84], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 88]
        vaddss  xmm1, xmm1, dword ptr [rbx + 88]
        vmovss  dword ptr [rbx + 88], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 92]
        vaddss  xmm1, xmm1, dword ptr [rbx + 92]
        vmovss  dword ptr [rbx + 92], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 96]
        vaddss  xmm1, xmm1, dword ptr [rbx + 96]
        vmovss  dword ptr [rbx + 96], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 100]
        vaddss  xmm1, xmm1, dword ptr [rbx + 100]
        vmovss  dword ptr [rbx + 100], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 104]
        vaddss  xmm1, xmm1, dword ptr [rbx + 104]
        vmovss  dword ptr [rbx + 104], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 108]
        vaddss  xmm1, xmm1, dword ptr [rbx + 108]
        vmovss  dword ptr [rbx + 108], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 112]
        vaddss  xmm1, xmm1, dword ptr [rbx + 112]
        vmovss  dword ptr [rbx + 112], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 116]
        vaddss  xmm1, xmm1, dword ptr [rbx + 116]
        vmovss  dword ptr [rbx + 116], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 120]
        vaddss  xmm1, xmm1, dword ptr [rbx + 120]
        vmovss  dword ptr [rbx + 120], xmm1
        # Last element (index 31): xmm0 itself is reused as the destination.
        vmulss  xmm0, xmm0, dword ptr [r14 + 124]
        vaddss  xmm0, xmm0, dword ptr [rbx + 124]
        vmovss  dword ptr [rbx + 124], xmm0
        add     rsp, 8
        pop     rbx
        pop     r14
        ret
# Compiler output for dot_prod<float, 32>. Unlike adapt(), this one IS
# vectorized: it uses packed (ps) instructions on 256-bit ymm registers.
float dot_prod<float, 32ul>(float const*, float const*):          # @float dot_prod<float, 32ul>(float const*, float const*)
        # Load four 8-float chunks from rsi and multiply each by the matching
        # chunk at rdi (4 x 8 = 32 elements), adding the products pairwise.
        vmovups ymm0, ymmword ptr [rsi]
        vmovups ymm1, ymmword ptr [rsi + 32]
        vmovups ymm2, ymmword ptr [rsi + 64]
        vmulps  ymm2, ymm2, ymmword ptr [rdi + 64]
        vmovups ymm3, ymmword ptr [rsi + 96]
        vmulps  ymm0, ymm0, ymmword ptr [rdi]
        vaddps  ymm0, ymm0, ymm2
        vmulps  ymm2, ymm3, ymmword ptr [rdi + 96]
        vmulps  ymm1, ymm1, ymmword ptr [rdi + 32]
        vaddps  ymm1, ymm1, ymm2
        vaddps  ymm0, ymm0, ymm1
        # Horizontal reduction: fold the upper 128 bits onto the lower, then
        # fold 4 lanes to 2 and 2 lanes to 1, leaving the sum in xmm0.
        vextractf128    xmm1, ymm0, 1
        vaddps  xmm0, xmm0, xmm1
        vpermilpd       xmm1, xmm0, 1           # xmm1 = xmm0[1,0]
        vaddps  xmm0, xmm0, xmm1
        vmovshdup       xmm1, xmm0              # xmm1 = xmm0[1,1,3,3]
        vaddss  xmm0, xmm0, xmm1
        vzeroupper
        ret
NlmsFilter<float, 32ul>::mu:
        .long   0x3f000000                      # float 0.5

NlmsFilter<float, 32ul>::eps:
        .long   0x3f800000                      # float 1

.L.str:
        .asciz  "signal.size() <= N + 1"

.L.str.1:
        .asciz  "/app/example.cpp"

.L__PRETTY_FUNCTION__.NlmsFilter<float, 32ul>::predict(std::span<float const, 18446744073709551615ul>):
        .asciz  "F NlmsFilter<float, 32>::predict(std::span<const F>) [F = float, N = 32]"
c++ clang vectorization simd auto-vectorization
1个回答
1
投票

问题是 clang 无法确定

w
是否(部分)与
signal
(或
signal
指向的数据)存在别名(即内存重叠)。

在这种情况下,

GCC 会比较

w.data() 和
signal.data()
指针。如果两者相距足够远,就执行向量化版本的循环,否则退回到标量循环。 Clang 可能认为潜在的性能增益不足以抵消这项额外运行时检查的开销(如果增大数组大小,到某个规模后 Clang 也会这样做)。

您可以使用

__restrict
关键字告诉 GCC/Clang 指向的内存不会重叠(这是非标准 C++。C 确实提供了
restrict
关键字)。但是,它只能用于指针,不能直接用于
std::span
(据我所知)。您可以通过提供一个辅助函数来解决此问题,该函数将其输出地址作为
__restrict
修饰的普通指针,例如:

/// Adds a scaled copy of `in` to `out`: out[i] += scale * in[i] for i in [0, N).
///
/// `out` is declared `__restrict` (a non-standard extension accepted by
/// GCC/Clang/MSVC), which promises the compiler that the memory written through
/// `out` is not accessed through any other pointer in this function. The caller
/// is responsible for that promise: `out` and `in` must not overlap.
///
/// @tparam F     element type
/// @tparam N     number of elements processed
/// @param out    destination, at least N elements, updated in place
/// @param in     source, at least N elements, must not overlap `out`
/// @param scale  factor applied to every element of `in`
template<typename F, size_t N>
inline void scale(F * __restrict out, F const* in, F scale)
{
    for(size_t i=0; i < N; ++i)
    {
        // Accumulate rather than overwrite: the weight-update loop this helper
        // stands in for adds to the existing values (`w[i] += ...`), so a plain
        // assignment here would throw away the current weights.
        out[i] += scale * in[i];
    }
}

然后用下面的调用替换你的循环:

scale<F, N>(w.data(), signal.data(), nu*e);

修改后的 godbolt-链接:https://godbolt.org/z/WWbdz489e

最新问题
© www.soinside.com 2019 - 2024. All rights reserved.