为什么 clang 17.0.1 不会矢量化以下函数中的循环:
// One NLMS adaptation step: updates the weight vector w from the first N
// samples of `signal`, using signal[N] (the last element) as the desired value.
void adapt(std::span<const F, N + 1> signal)
{
// Desired (training) value is the last element of the (N+1)-sample window.
F true_val = signal.back();
// Filter prediction from the current weights.
F y = dot_prod<F, N>(&signal[0], w.begin());
// Prediction error.
F e = true_val - y;
// Input energy, used to normalize the step size.
F dot = dot_prod<F, N>(&signal[0], &signal[0]);
// Normalized step size (eps keeps the denominator away from zero).
F nu = mu / (eps + dot);
// The loop in question: clang emits scalar code here (see the answer below —
// it cannot prove w and signal do not alias).
for (size_t i = 0; i < N; i++)
{
w[i] += nu * e * signal[i];
}
}
不存在循环携带依赖(loop-carried dependency)或浮点结合性问题,GCC 13.2 可以毫无问题地对其进行矢量化。
其背景是我正在尝试优化我的代码以使用矢量化点积。通常,由于浮点结合性问题,
std::inner_product
会发出标量实现,除非您使用 -ffast-math
。但是,我只希望 -ffast-math
应用于单个函数,并且我一直在寻找一种可移植的方法来为 clang 和 GCC 执行此操作。在查看输出时,我注意到 clang 不会矢量化另一个循环。
这是完整的 C++:
#include <array>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <span>
// Relax FP precision for just this function so the reduction may be
// reassociated and vectorized (MSVC/clang-style pragma).
#pragma float_control(precise, off, push)
template <typename F, size_t N>
// GCC-specific per-function -ffast-math; clang does not support the
// `optimize` attribute and ignores it (hence the pragma above).
__attribute__((optimize("-ffast-math")))
constexpr inline F dot_prod(const F *a, const F *b)
{
F acc = 0.0f;
// Plain reduction over N elements; with relaxed FP semantics the
// compiler is free to vectorize it.
for(size_t i = 0; i < N; i++)
acc += a[i] * b[i];
return acc;
}
#pragma float_control(pop)
// Normalized least-mean-squares (NLMS) adaptive filter with N weights.
// F is the sample type (e.g. float), N the filter order.
template <typename F, size_t N> class NlmsFilter
{
// Adaptation step size.
static constexpr F mu = 0.5;
// Regularizer keeping the normalization denominator away from zero.
static constexpr F eps = 1.0;
// Filter weights, zero-initialized.
std::array<F, N> w = {};
public:
// Predict the next sample from up to N+1 recent samples.
// With a full N+1 window this also adapts the weights first.
F predict(std::span<const F> signal)
{
assert(signal.size() <= N + 1);
if (signal.size() == N + 1)
{
// Convert to a fixed-extent span so adapt/dot_prod get compile-time N.
auto signal_static = signal.template subspan<0, N + 1>();
adapt(signal_static);
// Predict from the most recent N samples (indices 1..N).
return dot_prod<F, N>(&signal_static[1], w.begin());
}
else if (signal.size() == 0)
{
// No history yet: predict zero.
return 0.0f;
}
else
{
// Not enough samples to run the filter: repeat the last sample.
return signal.back();
}
}
// One NLMS update: signal[0..N-1] are inputs, signal[N] the desired value.
void adapt(std::span<const F, N + 1> signal)
{
F true_val = signal.back();
// Prediction with current weights.
F y = dot_prod<F, N>(&signal[0], w.begin());
// Prediction error.
F e = true_val - y;
// Input energy for step-size normalization.
F dot = dot_prod<F, N>(&signal[0], &signal[0]);
F nu = mu / (eps + dot);
// The loop clang leaves scalar (potential aliasing of w and signal).
for (size_t i = 0; i < N; i++)
{
w[i] += nu * e * signal[i];
}
}
};
template class NlmsFilter<float, 32>;
这是 clang 的汇编输出:
NlmsFilter<float, 32ul>::predict(std::span<float const, 18446744073709551615ul>): # @NlmsFilter<float, 32ul>::predict(std::span<float const, 18446744073709551615ul>)
push r14
push rbx
push rax
cmp rdx, 34
jae .LBB0_6
test rdx, rdx
je .LBB0_2
mov rbx, rsi
cmp rdx, 33
jne .LBB0_4
mov r14, rdi
mov rsi, rbx
call NlmsFilter<float, 32ul>::adapt(std::span<float const, 33ul>)
add rbx, 4
mov rdi, rbx
mov rsi, r14
add rsp, 8
pop rbx
pop r14
jmp float dot_prod<float, 32ul>(float const*, float const*) # TAILCALL
.LBB0_2:
vxorps xmm0, xmm0, xmm0
add rsp, 8
pop rbx
pop r14
ret
.LBB0_4:
vmovss xmm0, dword ptr [rbx + 4*rdx - 4] # xmm0 = mem[0],zero,zero,zero
add rsp, 8
pop rbx
pop r14
ret
.LBB0_6:
mov edi, offset .L.str
mov esi, offset .L.str.1
mov ecx, offset .L__PRETTY_FUNCTION__.NlmsFilter<float, 32ul>::predict(std::span<float const, 18446744073709551615ul>)
mov edx, 27
call __assert_fail
.LCPI1_0:
.long 0x3f800000 # float 1
.LCPI1_1:
.long 0x3f000000 # float 0.5
NlmsFilter<float, 32ul>::adapt(std::span<float const, 33ul>): # @NlmsFilter<float, 32ul>::adapt(std::span<float const, 33ul>)
push r14
push rbx
push rax
mov r14, rsi
mov rbx, rdi
vmovss xmm0, dword ptr [rsi + 128] # xmm0 = mem[0],zero,zero,zero
vmovss dword ptr [rsp + 4], xmm0 # 4-byte Spill
mov rdi, rsi
mov rsi, rbx
call float dot_prod<float, 32ul>(float const*, float const*)
vmovss xmm1, dword ptr [rsp + 4] # 4-byte Reload
vsubss xmm0, xmm1, xmm0
vmovss dword ptr [rsp + 4], xmm0 # 4-byte Spill
mov rdi, r14
mov rsi, r14
call float dot_prod<float, 32ul>(float const*, float const*)
vaddss xmm0, xmm0, dword ptr [rip + .LCPI1_0]
vmovss xmm1, dword ptr [rip + .LCPI1_1] # xmm1 = mem[0],zero,zero,zero
vdivss xmm0, xmm1, xmm0
vmulss xmm0, xmm0, dword ptr [rsp + 4] # 4-byte Folded Reload
vmulss xmm1, xmm0, dword ptr [r14]
vaddss xmm1, xmm1, dword ptr [rbx]
vmovss dword ptr [rbx], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 4]
vaddss xmm1, xmm1, dword ptr [rbx + 4]
vmovss dword ptr [rbx + 4], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 8]
vaddss xmm1, xmm1, dword ptr [rbx + 8]
vmovss dword ptr [rbx + 8], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 12]
vaddss xmm1, xmm1, dword ptr [rbx + 12]
vmovss dword ptr [rbx + 12], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 16]
vaddss xmm1, xmm1, dword ptr [rbx + 16]
vmovss dword ptr [rbx + 16], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 20]
vaddss xmm1, xmm1, dword ptr [rbx + 20]
vmovss dword ptr [rbx + 20], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 24]
vaddss xmm1, xmm1, dword ptr [rbx + 24]
vmovss dword ptr [rbx + 24], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 28]
vaddss xmm1, xmm1, dword ptr [rbx + 28]
vmovss dword ptr [rbx + 28], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 32]
vaddss xmm1, xmm1, dword ptr [rbx + 32]
vmovss dword ptr [rbx + 32], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 36]
vaddss xmm1, xmm1, dword ptr [rbx + 36]
vmovss dword ptr [rbx + 36], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 40]
vaddss xmm1, xmm1, dword ptr [rbx + 40]
vmovss dword ptr [rbx + 40], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 44]
vaddss xmm1, xmm1, dword ptr [rbx + 44]
vmovss dword ptr [rbx + 44], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 48]
vaddss xmm1, xmm1, dword ptr [rbx + 48]
vmovss dword ptr [rbx + 48], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 52]
vaddss xmm1, xmm1, dword ptr [rbx + 52]
vmovss dword ptr [rbx + 52], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 56]
vaddss xmm1, xmm1, dword ptr [rbx + 56]
vmovss dword ptr [rbx + 56], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 60]
vaddss xmm1, xmm1, dword ptr [rbx + 60]
vmovss dword ptr [rbx + 60], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 64]
vaddss xmm1, xmm1, dword ptr [rbx + 64]
vmovss dword ptr [rbx + 64], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 68]
vaddss xmm1, xmm1, dword ptr [rbx + 68]
vmovss dword ptr [rbx + 68], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 72]
vaddss xmm1, xmm1, dword ptr [rbx + 72]
vmovss dword ptr [rbx + 72], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 76]
vaddss xmm1, xmm1, dword ptr [rbx + 76]
vmovss dword ptr [rbx + 76], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 80]
vaddss xmm1, xmm1, dword ptr [rbx + 80]
vmovss dword ptr [rbx + 80], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 84]
vaddss xmm1, xmm1, dword ptr [rbx + 84]
vmovss dword ptr [rbx + 84], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 88]
vaddss xmm1, xmm1, dword ptr [rbx + 88]
vmovss dword ptr [rbx + 88], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 92]
vaddss xmm1, xmm1, dword ptr [rbx + 92]
vmovss dword ptr [rbx + 92], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 96]
vaddss xmm1, xmm1, dword ptr [rbx + 96]
vmovss dword ptr [rbx + 96], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 100]
vaddss xmm1, xmm1, dword ptr [rbx + 100]
vmovss dword ptr [rbx + 100], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 104]
vaddss xmm1, xmm1, dword ptr [rbx + 104]
vmovss dword ptr [rbx + 104], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 108]
vaddss xmm1, xmm1, dword ptr [rbx + 108]
vmovss dword ptr [rbx + 108], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 112]
vaddss xmm1, xmm1, dword ptr [rbx + 112]
vmovss dword ptr [rbx + 112], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 116]
vaddss xmm1, xmm1, dword ptr [rbx + 116]
vmovss dword ptr [rbx + 116], xmm1
vmulss xmm1, xmm0, dword ptr [r14 + 120]
vaddss xmm1, xmm1, dword ptr [rbx + 120]
vmovss dword ptr [rbx + 120], xmm1
vmulss xmm0, xmm0, dword ptr [r14 + 124]
vaddss xmm0, xmm0, dword ptr [rbx + 124]
vmovss dword ptr [rbx + 124], xmm0
add rsp, 8
pop rbx
pop r14
ret
float dot_prod<float, 32ul>(float const*, float const*): # @float dot_prod<float, 32ul>(float const*, float const*)
vmovups ymm0, ymmword ptr [rsi]
vmovups ymm1, ymmword ptr [rsi + 32]
vmovups ymm2, ymmword ptr [rsi + 64]
vmulps ymm2, ymm2, ymmword ptr [rdi + 64]
vmovups ymm3, ymmword ptr [rsi + 96]
vmulps ymm0, ymm0, ymmword ptr [rdi]
vaddps ymm0, ymm0, ymm2
vmulps ymm2, ymm3, ymmword ptr [rdi + 96]
vmulps ymm1, ymm1, ymmword ptr [rdi + 32]
vaddps ymm1, ymm1, ymm2
vaddps ymm0, ymm0, ymm1
vextractf128 xmm1, ymm0, 1
vaddps xmm0, xmm0, xmm1
vpermilpd xmm1, xmm0, 1 # xmm1 = xmm0[1,0]
vaddps xmm0, xmm0, xmm1
vmovshdup xmm1, xmm0 # xmm1 = xmm0[1,1,3,3]
vaddss xmm0, xmm0, xmm1
vzeroupper
ret
NlmsFilter<float, 32ul>::mu:
.long 0x3f000000 # float 0.5
NlmsFilter<float, 32ul>::eps:
.long 0x3f800000 # float 1
.L.str:
.asciz "signal.size() <= N + 1"
.L.str.1:
.asciz "/app/example.cpp"
.L__PRETTY_FUNCTION__.NlmsFilter<float, 32ul>::predict(std::span<float const, 18446744073709551615ul>):
.asciz "F NlmsFilter<float, 32>::predict(std::span<const F>) [F = float, N = 32]"
问题在于 clang 无法确定
w
是否(部分)与 signal
(或 signal
所指向的数据)发生别名(aliasing)。
在这种情况下,GCC 会比较
w.data()
和 signal.data()
指针。如果它们差异足够大,它将向量化循环,否则它会退回到标量循环。 Clang 可能认为潜在的性能增益并不能证明额外测试的成本是合理的(如果增加数组大小,在某些时候确实如此)。
您可以使用
__restrict
关键字告诉 GCC/Clang 指向的内存不会重叠(这是非标准 C++。C 确实提供了 restrict
关键字)。但是,这只能应用于不直接指向 std::span
的指针(据我所知)。您可以通过提供一个辅助函数来解决此问题,该函数将其输出地址作为 __restrict
-ed 普通指针,例如:
// Vectorizable scaled accumulation: out[i] += scale * in[i] for i in [0, N).
// `__restrict` on `out` promises the compiler that `out` and `in` do not
// overlap, so clang can vectorize without emitting a runtime aliasing check.
// The caller must guarantee the two N-element ranges really are disjoint.
template<typename F, size_t N>
inline void scale(F * __restrict out, F const* in, F scale)
{
for(size_t i=0; i < N; ++i)
{
// += (not =): this helper replaces the original update
// `w[i] += nu * e * signal[i]`, so the existing weights must be
// accumulated into, not overwritten.
out[i] += scale * in[i];
}
}
然后用如下调用替换原来的循环:
scale<F, N>(w.data(), signal.data(), nu*e);
修改后的 godbolt-链接:https://godbolt.org/z/WWbdz489e