快速查找百万位数字中连续 1 位的最大数量

问题描述 投票:0回答:1

比如123456789用二进制表示就是

111010110111100110100010101
。连续 1 位的最大数目是
4
。我有兴趣有效地解决非常大的数字(百万位甚至更多)。我想到了这个:

def onebits(n):
    ctr = 0
    while n:
        n &= n >> 1
        ctr += 1
    return ctr

n &= n >> 1
同时切断每条连续 1 位的最高 1 位。我重复直到每条条纹都消失,计算它走了多少步。例如(所有二进制):

   11101011 (start value)
->  1100001
->   100000
->        0

需要三步,因为最长的连胜有三个 1 位。

对于 random 数字,其中条纹很短,这是 相当快(在该基准中为

Kelly3
)。但是对于具有长条纹的数字,它最多需要 O(b²) 时间,其中 b 是 n 的大小(以位为单位)。我们可以做得更好吗?

python performance binary biginteger
1个回答
1
投票

是的,我们可以通过进一步发展该想法在 O(b log b) 时间内完成。 使用 指数搜索.

请注意,通过切断每条条纹的顶部 1 位,这也会扩大条纹之间的间隙。最初,条纹至少由 one 0 位分隔。在切断每条条纹的第一个 1 位后,条纹现在至少被 two 0 位分开。

然后我们可以做

n &= n >> 2
来切断所有剩余条纹的前 two 1 位。这也将差距扩大到至少 four 0 位。

只要我们仍然有 1 位条纹,我们就会继续从每个条纹的开始切断 4、8、16、32 等 1 位。

比方说,当我们试图切断 32 时,我们发现我们没有留下任何连胜。此时我们切换到反向模式。尝试切断 16。然后是 8、4、2,最后是 1。但只保留仍然让我们保持连胜的削减。

代码:

def onebits_linearithmic(n):
    if n == 0:
        return 0
    total = cut = 1
    while m := n & (n >> cut):
        n = m
        total += cut
        cut *= 2
    while cut := cut // 2:
        if m := n & (n >> cut):
            n = m
            total += cut
    return total

带有 随机 1,000,000 位数字的基准

  0.58 ± 0.06 ms  onebits_linearithmic
  1.18 ± 0.09 ms  onebits_quadratic
 44.10 ± 1.24 ms  onebits_linear

我包含了一个使用字符串的线性时间解决方案,但它的隐藏常数要高得多,所以它仍然慢得多。

接下来,具有 50,000 位连续 1 位的 100,000 位随机数

  0.11 ± 0.00 ms  onebits_linearithmic
  2.03 ± 0.19 ms  onebits_linear
173.89 ± 2.51 ms  onebits_quadratic

二次解确实变慢了很多。其他两个保持快速,所以让我们用 随机 1,000,000 位数字和 500,000 位连续 1 位的数字来尝试它们

 1.29 ± 0.02 ms  onebits_linearithmic
21.59 ± 0.26 ms  onebits_linear

完整代码(在线尝试!):

def onebits_quadratic(n):
    ctr = 0
    while n:
        n &= n >> 1
        ctr += 1
    return ctr

def onebits_linearithmic(n):
    if n == 0:
        return 0
    total = cut = 1
    while m := n & (n >> cut):
        n = m
        total += cut
        cut *= 2
    while cut := cut // 2:
        if m := n & (n >> cut):
            n = m
            total += cut
    return total

def onebits_linear(n):
    return max(map(len, filter(None, f'{n:b}'.split('0'))), default=0)

funcs = onebits_quadratic, onebits_linearithmic, onebits_linear

import random
from timeit import repeat
from statistics import mean, stdev

# Correctness
for n in [*range(10000), *(random.getrandbits(10) for _ in range(1000))]:
  expect = funcs[0](n)
  for f in funcs[1:]:
    if f(n) != expect:
      print('fail:', f.__name__)
    
def test(bits, number, unit, scale, long_streak, funcs):
  if not long_streak:
    print(f'random {bits:,}-bit numbers:')
  else:
    print(f'random {bits:,}-bit numbers with {bits//2:,}-bit streak of 1-bits:')

  times = {f: [] for f in funcs}
  def stats(f):
    ts = [t * scale for t in times[f][5:]]
    return f'{mean(ts):6.2f} ± {stdev(ts):4.2f} {unit} '

  for _ in range(10):
    n = random.getrandbits(bits)
    if long_streak:
      n |= ((1 << (bits//2)) - 1) << (bits//4)
    for f in funcs:
      t = min(repeat(lambda: f(n), number=number)) / number
      times[f].append(t)

  for f in sorted(funcs, key=stats):
    print(stats(f), f.__name__)
  print()

test(1_000_000, 1, 'ms', 1e3, False, funcs)
test(100_000, 1, 'ms', 1e3, True, funcs)
test(1_000_000, 1, 'ms', 1e3, True, funcs[1:])
© www.soinside.com 2019 - 2024. All rights reserved.