为什么我不能用 XOR、AND 和左移将两个整数相加?

问题描述 投票:0回答:1

我试着在 LeetCode 问题 Sum of Two Integers 上运行这个

class Solution:
    def getSum(self, a: int, b: int) -> int:
        while b:
            withoutcarry = (a ^ b)
            b = (a & b) << 1
            a = withoutcarry
        return a

以下测试用例正确运行,LeetCode 的 26 个测试用例中总共有 8 个:

a, b = 1, 2

a, b = 2, 3

还有其他一些但不是这个:

a, b = -1, 1

超过时间限制

我尝试使用

mask
变量运行以下代码:

class Solution:
    def getSum(self, a: int, b: int) -> int:
        mask = 0xffffffff
        while b:
            sum_ = (a ^ b) & mask
            carry = ((a & b) << 1) & mask
            a = sum_
            b = carry
        if (a >> 31) & 1:
            return ~(a ^ mask)
        return a

它能够通过我之前列出的所有 3 个测试用例,并且在 LeetCode 的 26 个测试用例中,总共有 20 个没有通过:

a, b = -1, 0

它的输出是:

4294967295

然后我尝试了以下代码:

class Solution:
    def getSum(self, a: int, b: int) -> int:
        mask = 0xffffffff

        while b != 0:
            tmp = (a & b) << 1
            a = (a ^ b) & mask
            b = tmp & mask

        if a > mask // 2:
            return ~(a ^ mask)
        else:
            return a

这次它成功了并通过了所有测试用例,但我不知道为什么它有效。

python-3.x bit-manipulation bit-shift xor
1个回答
1
投票

您的第一个尝试解决方案根本不是为处理负数而设计的,一旦引入负数,它就会失败。这样做的主要原因是 Python 的“长”整数数据类型能够支持无限位大小。当您添加一个负数以得到一个正数时,您正在使用的“xor to add,然后如果有任何位可以携带则重复”算法会尝试永远重复。使用固定的位宽,清除符号位后最终会溢出数据类型的末尾,因此循环将停止。

例如,如果您的整数只有四位长,则添加

-1
1
会得到
a
b
的这些值:

a=0b1111 b=0b0001
a=0b1110 b=0b0010
a=0b1100 b=0b0100
a=0b1000 b=0b1000
a=0b0000 b=0b0000  # after << overflows, the loop stops here since `b` is zero

但是,Python 的整数永远不会溢出,所以在“最后”步骤中,您最终会得到

0b10000
代表
b
,而
a
仍然会有无限数量的前导
1
,并且您将永远向上携带该位,而不会到达符号位和循环结尾。

您的代码的第二个版本试图通过对值施加

mask
来解决溢出问题。它的缺陷是,如果
a
最初为负且
b
最初为零,则循环永远不会运行,因此掩码永远不会应用于
a
。这使得
a
仍然是负数,而其余代码期望所有
a
值在最后都是正数。这导致代码试图通过设置符号位来修复值来做错误的事情,并且不适当地将负值反转回正值。

代码的第三个版本更改了测试应该更仔细地反转的值的检查,因此按预期工作。如果

b
为零,它仍然允许负数通过未屏蔽,但它不会错误地将它们重新转换为正值。

如果您在函数开始时无条件地将掩码应用于

a
,而不是仅在循环中,您也可以使代码的第二个版本按预期工作:

class Solution:
    def getSum(self, a: int, b: int) -> int:
        mask = 0xffffffff
        a &= mask                                    # fix is here
        while b:
            sum_ = (a ^ b) & mask
            carry = ((a & b) << 1) & mask
            a = sum_
            b = carry
        if (a >> 31) & 1:
            return ~(a ^ mask)
        return a

现在,所有

a
值都将被屏蔽(将负值变为正值),最后的位检查将正确识别哪些结果应被视为负值。

© www.soinside.com 2019 - 2024. All rights reserved.