将4个uint16_t打包在uint64_t中的快速模12算法

问题描述 投票:9回答:5

考虑以下联合:

union Uint16Vect {
    uint16_t _comps[4];
    uint64_t _all;
};

是否存在用于确定每个分量是否等于1模12的快速算法?

简单的代码序列是:

Uint16Vect F(const Uint16Vect a) {
    Uint16Vect r;
    for (int8_t k = 0; k < 4; k++) {
        r._comps[k] = (a._comps[k] % 12 == 1) ? 1 : 0;
    }
    return r;
}
c algorithm vectorization modulo avx2
5个回答
12
投票

像这样的常数除法应该用multiplication by the multiplicative inverse完成。如您所见,编译器优化x/12 to x*43691 >> 19

x/12

由于SSE / AVX中有乘法指令,因此可以很容易地将其向量化。此外,可以将x*43691 >> 19简化为bool h(uint16_t x) { return x % 12 == 1; } h(unsigned short): movzx eax, di imul eax, eax, 43691 ; = 0xFFFF*8/12 + 1 shr eax, 19 lea eax, [rax+rax*2] sal eax, 2 sub edi, eax cmp di, 1 sete al ret ,然后将其转换为x = (x % 12 == 1) ? 1 : 0;,从而避免了常数表要比较的值1。您可以使用x = (x % 12 == 1),以便gcc自动为您生成代码

x = (x - 1) % 12 == 0

下面是vector extension

typedef uint16_t ymm __attribute__((vector_size(32)));
ymm mod12(ymm x)
{
    return !!((x - 1) % 12);
}

Clang和ICC在向量类型上不支持output from gcc,因此您需要更改为mod12(unsigned short __vector(16)): vpcmpeqd ymm3, ymm3, ymm3 ; ymm3 = -1 vpaddw ymm0, ymm0, ymm3 vpmulhuw ymm1, ymm0, YMMWORD PTR .LC0[rip] ; multiply with 43691 vpsrlw ymm2, ymm1, 3 vpsllw ymm1, ymm2, 1 vpaddw ymm1, ymm1, ymm2 vpsllw ymm1, ymm1, 2 vpcmpeqw ymm0, ymm0, ymm1 vpandn ymm0, ymm0, ymm3 ret 。不幸的是,似乎编译器不支持!!发出MMX指令。但是现在您还是应该使用SSE或AVX

(x - 1) % 12 == 0的输出较短,如您在上面的同一Godbolt链接中所见,但是您需要一个包含1s的表进行比较,这可能会更好。检查哪种方法在您的情况下工作更快


或者对于这样小的输入范围,您可以使用查找表。基本版本需要65536个元素的数组

__attribute__((vector_size(8))

要使用,只需将x % 12 == 1替换为#define S1(x) ((x) + 0) % 12 == 1, ((x) + 1) % 12 == 1, ((x) + 2) % 12 == 1, ((x) + 3) % 12 == 1, \ ((x) + 4) % 12 == 1, ((x) + 4) % 12 == 1, ((x) + 6) % 12 == 1, ((x) + 7) % 12 == 1 #define S2(x) S1((x + 0)*8), S1((x + 1)*8), S1((x + 2)*8), S1((x + 3)*8), \ S1((x + 4)*8), S1((x + 4)*8), S1((x + 6)*8), S1((x + 7)*8) #define S3(x) S2((x + 0)*8), S2((x + 1)*8), S2((x + 2)*8), S2((x + 3)*8), \ S2((x + 4)*8), S2((x + 4)*8), S2((x + 6)*8), S2((x + 7)*8) #define S4(x) S3((x + 0)*8), S3((x + 1)*8), S3((x + 2)*8), S3((x + 3)*8), \ S3((x + 4)*8), S3((x + 4)*8), S3((x + 6)*8), S3((x + 7)*8) bool mod12e1[65536] = { S4(0U), S4(8U), S4(16U), S4(24U), S4(32U), S4(40U), S4(48U), S4(56U) } 。这当然可以向量化

但是由于结果只有1或0,所以您也可以使用65536位数组将大小减小到仅8KB


您还可以通过4除数和3来检查12除数。4除数显然很简单。 3的除数可以通过多种方式计算

  • x % 12 == 1中一样计算奇数位和偶数和]之间的,并检查是否可以被3整除
  • 或者,您可以检查以2为底的数字之和[[2k

  • (例如以4、16、64为底的数字),以查看是否可以被3整除。之所以有用,是因为在以mod12e1[x]为基数来检查גלעד ברקן's answer的任何除数n的除数中,只需检查数字的总和是否可被n整除。这是它的实现

    b

    将3除以b - 1的功劳

由于自动矢量化的程序集输出太长,您可以在void modulo12equals1(uint16_t d[], uint32_t size) { for (uint32_t i = 0; i < size; i++) { uint16_t x = d[i] - 1; bool divisibleBy4 = x % 4 == 0; x = (x >> 8) + (x & 0x00ff); // max 1FE x = (x >> 4) + (x & 0x000f); // max 2D bool divisibleBy3 = !!((01111111111111111111111ULL >> x) & 1); d[i] = divisibleBy3 && divisibleBy4; } } 上进行检查

另请参见

  • Godbolt link
  • How to know if a binary number divides by 3?
  • Determine whether or not a binary number is divisible by 3
  • Bit representation and divisibility by 3
  • building circuit for divisibility by 3

  • 2
    投票
    [如果这有助于将操作限制在位操作和Check if a number is divisible by 3上,我们可以观察到有效的候选者必须通过两次测试,因为减1必须表示4和3的可除性。首先,最后两位必须为Logic to check the number is divisible by 3 or not? 。然后除以3,我们可以通过从偶数位置的popcount中减去奇数位置的popcount来找到。

    2
    投票
    这是我能想到的最好的

    1
    投票
    [最近有const evenMask = parseInt('1010101010101010', 2); // Leave out first bit, we know it will be zero // after subtracting 1 const oddMask = parseInt('101010101010100', 2); console.log('n , Test 1: (n & 3)^3, Test 2: popcount diff:\n\n'); for (let n=0; n<500; n++){ if (n % 12 == 1) console.log( n, (n & 3)^3, popcount(n & evenMask) - popcount(n & oddMask)) } // https://stackoverflow.com/questions/43122082/efficiently-count-the-number-of-bits-in-an-integer-in-javascript function popcount(n) { var tmp = n; var count = 0; while (tmp > 0) { tmp = tmp & (tmp - 1); count++; } return count; }关于快速余数计算和除数检查。例如,您可以使用uint64_t F(uint64_t vec) { //512 = 4 mod 12 -> max val 0x3FB vec = ((vec & 0xFE00FE00FE00FE00L) >> 7) + (vec & 0x01FF01FF01FF01FFL); //64 = 4 mod 12 -> max val 0x77 vec = ((vec & 0x03C003C003C003C0L) >> 4) + (vec & 0x003F003F003F003FL); //16 = 4 mod 12 -> max val 0x27 vec = ((vec & 0x0070007000700070L) >> 2) + (vec & 0x000F000F000F000FL); //16 = 4 mod 12 -> max val 0x13 vec = ((vec & 0x0030003000300030L) >> 2) + (vec & 0x000F000F000F000FL); //16 = 4 mod 12 -> max val 0x0f vec = ((vec & 0x0030003000300030L) >> 2) + (vec & 0x000F000F000F000FL); //Each field is now 4 bits, and only 1101 and 0001 are 1 mod 12. //The top 2 bits must be equal and the other2 must be 0 and 1 return vec & ~(vec>>1) & ~((vec>>2)^(vec>>3)) & 0x0001000100010001L; } 检查除数是否为12,或者使用post on Daniel Lemire's blog进行32位运算。这应该易于并行化,但是与phuclv的答案中的代码不同,它需要32位乘法。

    0
    投票
    请注意,((x * 43691) & 0x7ffff) < 43691表示x * 357913942 < 357913942,而后者非常便宜。
    © www.soinside.com 2019 - 2024. All rights reserved.