设置位:
给定一个数组
int inds[N]
,其中每个inds[i]
是[0, 255]
范围内的1位位置(并且所有inds[i]
都已排序且唯一),我需要将__m256i
的相应位设置为1。
还有比我下面做的更好的方法吗:
alignas(32) uint64_t buf[4] = {0};
for (int i = 0; i < N; ++i) {
int ind = inds[i];
buf[ind / 64] |= 1ul << (ind % 64);
}
auto r = _mm256_load_si256((__m256i*)buf);
获取位:
在相反的操作中,我需要计算位 1 位置处的双精度值的乘积。即,给定
double const sizes[256]
计算其中一些的乘积(在 __m256i
掩码给出的位置)。
inline
double size (__m256i r, double const sizes[256])
{
alignas(16) uint64_t buf[4];
_mm256_store_si256((__m256i*)buf, r);
double s[4] = {1.0, 1.0, 1.0, 1.0};
// __builtin_ctzl(i) gives next position
// and i &= i - 1 clears that bit
for (; buf[0] != 0; buf[0] &= buf[0] - 1)
s[0] *= sizes[__builtin_ctzl(buf[0]) + 0 * 64];
for (; buf[1] != 0; buf[1] &= buf[1] - 1)
s[1] *= sizes[__builtin_ctzl(buf[1]) + 1 * 64];
for (; buf[2] != 0; buf[2] &= buf[2] - 1)
s[2] *= sizes[__builtin_ctzl(buf[2]) + 2 * 64];
for (; buf[3] != 0; buf[3] &= buf[3] - 1)
s[3] *= sizes[__builtin_ctzl(buf[3]) + 3 * 64];
return s[0] * s[1] * s[2] * s[3];
}
同样的问题:更好的方法吗?
使用 AVX512 可以做到这一点,并且在某些情况下比标量方法更有效。但使用我使用的方法,当
N
较低时,它不一定有帮助,并且将其移植到 AVX2(这似乎是可能的)会使情况在这方面变得更糟。我会把它放在底部。也许有更好的方法。
还有其他一些问题,标量方法有一个可以解决的问题:通过内存进行循环传递依赖。例如GCC编译这样的代码,(相关部分摘录)
.L3:
movzx eax, BYTE PTR [rdi]
mov rdx, r8
add rdi, 1
mov rcx, rax
shr rax, 6
sal rdx, cl
or QWORD PTR [rsp-32+rax*8], rdx
cmp rsi, rdi
jne .L3
or
在(大多数)连续循环迭代中加载/存储相同的内存位置。可以通过为结果的每个块编写单独的循环来避免这种情况,
__m256i set_indexed_bits2(uint8_t* indexes, size_t N)
{
alignas(32) uint64_t buf[4] = { 0 };
if (N < 256)
indexes[N] = 255;
size_t i = 0;
while (indexes[i] < 64)
buf[0] |= 1ull << indexes[i++];
while (indexes[i] < 128)
buf[1] |= 1ull << indexes[i++];
while (indexes[i] < 192)
buf[2] |= 1ull << indexes[i++];
while (i < N)
buf[3] |= 1ull << indexes[i++];
return _mm256_load_si256((__m256i*)buf);
}
在源代码级别,看起来仍然存在通过内存的依赖关系,但是当以这种方式编写时(数组中的索引是常量),编译器可能会应用优化,暂时使用寄存器来表示
buf[0]
和等等,例如,这里是 GCC 的摘录(这相当代表其他编译器的做法):
.L15:
add rax, 1
mov r11, rdi
sal r11, cl
movzx ecx, BYTE PTR [rdx+rax]
or rsi, r11
cmp cl, -65
jbe .L15
mov QWORD PTR [rsp-16], rsi
好多了(尽管 GCC 错过了将
bts
与寄存器目标一起使用的机会,这与具有内存目标的版本不同)。事实上,在我的测试中,效果好两倍多,但这取决于 N
和其他因素。
这里是 AVX512 的黑客攻击。在我的 PC(rocket Lake)上,这比某些
N
的改进标量代码更快,大约 60 或更多。并不令人惊奇,但确实是这样。转换到 AVX2 似乎是可能的,但这会使它开始值得的门槛更高。如果有更好的方法有望改变事情。
__m512i indexes_to_bits64(__m512i indexes, __mmask64 valids)
{
__m512i one = _mm512_set1_epi64(1);
__m512i mask = _mm512_set1_epi64(63);
uint64_t m = valids;
// do all the shifts
__m128i i = _mm512_castsi512_si128(indexes);
__m512i b0 = _mm512_maskz_sllv_epi64(m, one, _mm512_and_epi64(_mm512_cvtepu8_epi64(i), mask));
i = _mm512_castsi512_si128(_mm512_permutex_epi64(indexes, 1));
__m512i b1 = _mm512_maskz_sllv_epi64(m >> 8, one, _mm512_and_epi64(_mm512_cvtepu8_epi64(i), mask));
i = _mm512_castsi512_si128(_mm512_permutex_epi64(indexes, 2));
__m512i b2 = _mm512_maskz_sllv_epi64(m >> 16, one, _mm512_and_epi64(_mm512_cvtepu8_epi64(i), mask));
i = _mm512_castsi512_si128(_mm512_permutex_epi64(indexes, 3));
__m512i b3 = _mm512_maskz_sllv_epi64(m >> 24, one, _mm512_and_epi64(_mm512_cvtepu8_epi64(i), mask));
indexes = _mm512_shuffle_i64x2(indexes, indexes, _MM_SHUFFLE(1, 0, 3, 2));
i = _mm512_castsi512_si128(indexes);
__m512i b4 = _mm512_maskz_sllv_epi64(m >> 32, one, _mm512_and_epi64(_mm512_cvtepu8_epi64(i), mask));
i = _mm512_castsi512_si128(_mm512_permutex_epi64(indexes, 1));
__m512i b5 = _mm512_maskz_sllv_epi64(m >> 40, one, _mm512_and_epi64(_mm512_cvtepu8_epi64(i), mask));
i = _mm512_castsi512_si128(_mm512_permutex_epi64(indexes, 2));
__m512i b6 = _mm512_maskz_sllv_epi64(m >> 48, one, _mm512_and_epi64(_mm512_cvtepu8_epi64(i), mask));
i = _mm512_castsi512_si128(_mm512_permutex_epi64(indexes, 3));
__m512i b7 = _mm512_maskz_sllv_epi64(m >> 56, one, _mm512_and_epi64(_mm512_cvtepu8_epi64(i), mask));
// OR the 8 parts together
__m512i b012 = _mm512_ternarylogic_epi64(b0, b1, b2, 0xFE);
__m512i b345 = _mm512_ternarylogic_epi64(b3, b4, b5, 0xFE);
__m512i b67 = _mm512_or_epi64(b6, b7);
return _mm512_ternarylogic_epi64(b012, b345, b67, 0xFE);
}
__m256i set_indexed_bits_avx512(uint8_t* indexes, int N)
{
// load values 0..63 into one chunk,
// 64..127 in the next chunk
// 128..191 in the third chunk
// 192..255 in the last chunk
// this automatically expanded based on bits 7 and 6
__m512i chunk0 = _mm512_loadu_epi8(indexes);
__mmask64 valids0 = _mm512_cmple_epu8_mask(chunk0, _mm512_set1_epi8(63));
int chunk0_count = std::countr_one(valids0);
valids0 = _bzhi_u64(valids0, N);
__m512i chunk1 = _mm512_loadu_epi8(indexes + chunk0_count);
__mmask64 valids1 = _mm512_cmple_epu8_mask(chunk1, _mm512_set1_epi8(127));
int chunk1_count = std::countr_one(valids1);
valids1 = _bzhi_u64(valids1, std::max(0, N - chunk0_count));
__m512i chunk2 = _mm512_loadu_epi8(indexes + chunk0_count + chunk1_count);
__mmask64 valids2 = _mm512_cmple_epu8_mask(chunk2, _mm512_set1_epi8(191));
int chunk2_count = std::countr_one(valids2);
valids2 = _bzhi_u64(valids2, std::max(0, N - chunk0_count - chunk1_count));
__m512i chunk3 = _mm512_loadu_epi8(indexes + chunk0_count + chunk1_count + chunk2_count);
__mmask64 valids3 = _bzhi_u64(-1ULL, std::max(0, N - chunk0_count - chunk1_count - chunk2_count));
// 1 << bottom 6 bits
chunk0 = indexes_to_bits64(chunk0, valids0);
chunk1 = indexes_to_bits64(chunk1, valids1);
chunk2 = indexes_to_bits64(chunk2, valids2);
chunk3 = indexes_to_bits64(chunk3, valids3);
// interleave and reduce horizontally
__m512i chunk01 = _mm512_or_epi64(
_mm512_unpacklo_epi64(chunk0, chunk1),
_mm512_unpackhi_epi64(chunk0, chunk1));
__m512i chunk23 = _mm512_or_epi64(
_mm512_unpacklo_epi64(chunk2, chunk3),
_mm512_unpackhi_epi64(chunk2, chunk3));
__m256i chunk01_2 = _mm256_or_si256(_mm512_castsi512_si256(chunk01), _mm512_extracti64x4_epi64(chunk01, 1));
__m256i chunk23_2 = _mm256_or_si256(_mm512_castsi512_si256(chunk23), _mm512_extracti64x4_epi64(chunk23, 1));
__m128i chunk01_3 = _mm_or_si128(_mm256_castsi256_si128(chunk01_2), _mm256_extracti128_si256(chunk01_2, 1));
__m128i chunk23_3 = _mm_or_si128(_mm256_castsi256_si128(chunk23_2), _mm256_extracti128_si256(chunk23_2, 1));
return _mm256_inserti128_si256(_mm256_castsi128_si256(chunk01_3), chunk23_3, 1);
}