如何在比较函数中优化多个独立条件分支?

问题描述 投票:3回答:2
struct Obj {
  int x;
  int y;
  int z;
};

int Compare(Obj* a, Obj* b) {
  if (a->x > b->x) return 1;
  else if (a->x < b->x) return -1;

  if (a->y > b->y) return 1;
  else if (a->y < b->y) return -1;

  if (a->z > b->z) return 1;
  else if (a->z < b->z) return -1;

  return 0;
}

如上面的代码所示,最多有三个条件分支来获得比较结果。比较函数将由某种函数调用。如何优化代码以杀死条件分支,从而提高比较功能的性能?

-更新-由于调用方函数是快速排序的改进版本,因此需要更大,更少和相等的结果。因此,比较功能应该用-1,1,0区分三个结果。

c++ performance comparison micro-optimization branch-prediction
2个回答
3
投票

C ++不是汇编语言,并且编译器可以如果需要,将当前函数编译为无分支汇编。 (取消引用结构指针以加载一个成员意味着,即使存在完整的struct对象,并且即使C ++抽象机不会接触y或z成员,也可以通过推测方式读取,而不会出错)。您最关心哪种架构?

您是否尝试使用配置文件引导的优化进行编译,以便编译器可以看到分支是不可预测的?这可能导致它根据目标ISA将if()转换为无分支cmov或其他方法。 (使用rand() & 0x7等生成随机数据,因此对象具有相等的x和y并实际上达到z的情况并不罕见。)


可以使用SIMD查找第一个不匹配的元素,然后返回该元素的差异[]。例如,x86 SIMD具有movemask操作,可以将向量比较结果转换为整数位掩码,我们可以将其与bitscan指令一起使用以找到第一个或最后一个设置位。

((这取决于能否从您的12字节结构中安全地读取16字节,假设x86。这种情况是这样,只要您的数组不以页面末尾的最后一个元素结尾, Is it safe to read past the end of a buffer within the same page on x86 and x64?通常是,并且广泛用于有效实现strlen和类似功能。)

((ARM NEON没有方便的移动蒙版,因此对于ARM / AArch64,如果SIMD绝对是赢家,则最好将SIMD向量内的数据改组以得出结果。ARM可能不是谓词比较指令,或使用仍比x86 CMOV更好的AArch64的更有限的无分支条件指令。)

SIMD可以为我们提供良好的吞吐量,但与@Scheff的branchless arithmetic version in comments相比,延迟可能较差,尤其是在像现代x86这样的宽管线上,它可以并行执行许多独立的工作(例如将单独的比较结果转换为布尔整数)。在QSort中,高延迟可能不是理想的选择,在QSort中,您期望分支错误预测不会罕见。重叠的独立比较与乱序执行仅在正确预测了分支时才起作用。

要从两个int值获得+ / 0 /-结果,可以将其强制转换为int64_t并减去。这样可以避免出现签名溢出的可能性,并且在64位ISA上非常有效。 (或者,如果可以内联,则理想情况下可以编译为32位带符号的比较而不是实际的减法。32位减法可能具有带符号的UB溢出,并且在包装时会丢失结果)。如果您不需要标准化为+1 / 0 / -1,请执行此操作。

[我在带有数组的联合体内使用匿名结构来扩展@Scheff's handy benchmark frameworkwith bugfix),而没有将所有内容从a->x更改为a->vals.x

#include <stdint.h>
#include <immintrin.h>

union Obj {
  struct { // extension: anonymous struct
    int x;
    int y;
    int z;
  };
  int elems[3];
};



// a better check would be on value ranges; sizeof can include padding
static_assert( sizeof(int64_t) > sizeof(int), "we need int smaller than int64_t");

int64_t compare_x86(const Obj *a, const Obj *b)
{
    __m128i va = _mm_loadu_si128((const __m128i*)a);  // assume over-read is safe, last array object isn't at the end of a page.
    __m128i vb = _mm_loadu_si128((const __m128i*)b);
    __m128i veq = _mm_cmpeq_epi32(va,vb);

    unsigned eqmsk = _mm_movemask_ps(_mm_castsi128_ps(veq));
    eqmsk |= 1<<2;   // set elems[2]'s bit so we'll return that (non)diff if they're all equal
    unsigned firstdiff = __builtin_ctz(eqmsk);   // GNU C extension: count trailing zeros

    // sign-extend to 64-bit first so overflow is impossible, giving a +, 0, or - result
    return a->elems[firstdiff] - (int64_t)b->elems[firstdiff];
}

[x86-64带有GCC9.3 On Godbolt-O3 -march=skylake -fno-tree-vectorize,对于非内联情况,它将编译为该asm:

compare_x86(Obj const*rdi, Obj const*rsi):
        vmovdqu xmm1, XMMWORD PTR [rsi]
        vpcmpeqd        xmm0, xmm1, XMMWORD PTR [rdi]
        vmovmskps       edx, xmm0               # edx = bitmask of the vector compare result
        or      edx, 4
        tzcnt   edx, edx                        # rdx = index of lowest set bit
        mov     edx, edx                        # stupid compiler, already zero-extended to 64-bit
        movsx   rax, DWORD PTR [rdi+rdx*4]      # 32->64 sign extending load
        movsx   rdx, DWORD PTR [rsi+rdx*4]
        sub     rax, rdx                        # return value in RAX
        ret

延迟关键路径

通过SIMD比较,通过移动掩码返回整数tzcnt / bsf(在Intel上为3个周期),然后对movsx个负载进行另一个L1d负载使用延迟(5个周期)。直到tzcnt之后才知道加载地址。因此,这里的ILP很少。但是,它可以在独立比较之间很好地重叠,并且总的uop计数很低,因此前端带宽的瓶颈也不太严重。

未对齐的SIMD负载不会对Intel CPU造成任何损失,除非它们越过缓存行边界。则等待时间是额外的10个周期左右。甚至更糟的是,如果它们越过4k边界,尤其是在Skylake使页面拆分便宜很多之前的Intel上。对于随机的4字节对齐的对象地址,在16个起始位置中有3个导致高速缓存行拆分负载(对于64B高速缓存行)。这进一步增加了从准备好输入地址到准备好比较结果的平均延迟,并且不能与任何工作重叠。

没有-march=skylake时,GCC使用单独的movdqu未对齐负载,以及与rep bsf相同的指令tzcnt。没有BMI1的CPU会将其解码为纯bsf。 (它们仅在输入为零时才不同;我们确保不会发生这种情况。AMD上bsf的速度很慢,与英特尔上tzcnt的速度相同。)

使用@Scheff的基准(对结果进行计数),它在禁用自动矢量化时比普通标量“算术”版本要快一些。 (GCC可以自动计算出算术版本。)运行之间的时序结果不一致,因为测试用例太小并且编译器资源管理器运行的AWS服务器可能具有不同的CPU频率,尽管它们都是Skylake-avx512。但是在一次运行中,在此和arith之间交替,这样的结果很典型:

compare_x86() 5. try: 28 mus (<: 3843, >: 3775)
compareArithm() 5. try: 59 mus (<: 4992, >: 5007)
compare_x86() 6. try: 39 mus (<: 3843, >: 3775)
compareArithm() 6. try: 64 mus (<: 4992, >: 5007)
compare_x86() 7. try: 27 mus (<: 3843, >: 3775)
compareArithm() 7. try: 64 mus (<: 4992, >: 5007)

但是请记住,这只是add]] <0>0返回值,因此受吞吐量限制,而不是延迟。新的比较可以开始,而对先前的比较结果没有任何数据依赖性或控制依赖性。


嗯,我本可以使用pmovmskb来获取每个字节的高位,而不是使用ps版本的每个双字,但是C使得在int数组中使用字节偏移量而不是在元素偏移量。在asm中,先输入tzcnt或BSF,然后再输入movsx rax, [rdi + rdx]。这样可以节省SIMD整数pcmpeqd和SIMD-FP movmskps之间旁路延迟的延迟周期。但是要从编译器中获取该信息,您可能必须将其转换为char*进行指针添加,然后再转换回int*


[我一开始以为使用_mm_cmpgt_epi32(va,vb)来得到一个向量,该向量比较有符号大于的结果,但是后来我意识到索引原始结构就像映射正确的元素或位一样容易。变成-1 / +1整数。

如果要对全等情况进行特殊处理,则可以将位#3设置为[|= 1<<3),然后在这种罕见情况下分支,但其余部分仍然无分支。

    eqmsk |= 1<<3;   // set the 4th bit so there's a non-zero bit to find
    unsigned firstdiff = __builtin_ctz(eqmsk);

    if (firstdiff >= 3)   // handle this rare(?) case with a branch
        return 0;

    ... something with  (a < b) * 2 - 1

混合分支策略:

[如果很少x是相等的,则可以考虑

   if (a->x != b->x)
       return  a->x - (int_fast64_t)b->x;
   else {
       8-byte branchless SIMD?
       or maybe just 2 element branchless scalar
   }

IDK,如果仅需再添加2个元素就值得进行SIMD。可能不是。

或者也许考虑对x和y进行无分支,并在y分量上分支等于跳跃标量z?如果您的对象在int的大部分范围内都是随机的,那么很难找到两个仅在最后一个成分上有所不同的对象。

[我认为,良好的排序算法通过避免多余的比较来减少比较次数的方式可能会在结果模式中产生更多的熵,并且可能还会增加对在最终排序顺序中彼此“接近”的元素进行的比较数量。因此,如果有很多元素的x相等,QSort可能会做更多的比较,这些比较确实需要检查y元素。

这是三向比较的通用模拟:

#include <tuple>

namespace util {
    template <typename T>
    int compare(const T& lhs, const T& rhs)
    {
        if (lhs == rhs) {
            return 0;
        } else if (lhs < rhs) {
            return -1;
        } else {
            return 1;
        }
    }

    namespace detail {
        template <typename Tuple>
        int compare_tuples(const Tuple&, const Tuple&, std::index_sequence<>)
        {
            return 0;
        }
        template <typename Tuple, std::size_t I, std::size_t... Is>
        int compare_tuples(const Tuple& lhs, const Tuple& rhs, std::index_sequence<I, Is...>)
        {
            if (auto cmp = compare(std::get<I>(lhs), std::get<I>(rhs))) {
                return cmp;
            } else {
                return compare_tuples(lhs, rhs, std::index_sequence<Is...>{});
            }
        }
    }

    template <typename Tuple>
    int compare_tuples(const Tuple& lhs, const Tuple& rhs)
    {
        return detail::compare_tuples(
            lhs, rhs, std::make_index_sequence<std::tuple_size_v<Tuple>>{}
        );
    }
}

然后您可以通过使用std::tie形成成员的元组来使用它:

struct Object {
    int x, y, z;
};

int compare(const Object& lhs, const Object& rhs)
{
    return util::compare_tuples(
        std::tie(lhs.x, lhs.y, lhs.z),
        std::tie(rhs.x, rhs.y, rhs.z)
    );
}

([live demo

compare函数最终是GCC的optimized

compare(Object const&, Object const&):
        mov     eax, DWORD PTR [rsi]
        cmp     DWORD PTR [rdi], eax
        je      .L11
.L2:
        setge   al
        movzx   eax, al
        lea     eax, [rax-1+rax]
        ret
.L11:
        mov     eax, DWORD PTR [rsi+4]
        cmp     DWORD PTR [rdi+4], eax
        jne     .L2
        mov     edx, DWORD PTR [rsi+8]
        xor     eax, eax
        cmp     DWORD PTR [rdi+8], edx
        jne     .L2
        ret

和叮当声:

compare(Object const&, Object const&):                 # @compare(Object const&, Object const&)
        mov     ecx, dword ptr [rdi]
        xor     eax, eax
        cmp     ecx, dword ptr [rsi]
        setge   cl
        jne     .LBB0_1
        mov     ecx, dword ptr [rdi + 4]
        xor     eax, eax
        cmp     ecx, dword ptr [rsi + 4]
        setge   cl
        jne     .LBB0_1
        mov     eax, dword ptr [rdi + 8]
        xor     ecx, ecx
        xor     edx, edx
        cmp     eax, dword ptr [rsi + 8]
        setge   dl
        lea     eax, [rdx + rdx - 1]
        cmove   eax, ecx
        ret

自C ++ 20起,此问题可以通过默认的宇宙飞船运算符轻松解决:

#include <compare>

struct Obj {
    int x;
    int y;
    int z;

    constexpr auto operator<=>(const Obj&) const = default;
};

int to_int(std::partial_ordering cmp) noexcept
{
    if (cmp == 0) {
        return 0;
    } else if (cmp < 0) {
        return -1;
    } else {
        return 1;
    }
}

int Compare(Obj* a, Obj* b)
{
    return to_int(*a <=> *b);
}

0
投票

这是三向比较的通用模拟:

© www.soinside.com 2019 - 2024. All rights reserved.