使用最大堆和最小堆求数据流中的第 K 小元素

问题描述 投票:0回答:2

我正在编写一个 Python 程序,使用最大堆(maxHeap)和最小堆(minHeap)计算所有 `m[i]` 的总和,其中 `m[i]` 是 `A[0], A[1], ..., A[i]` 中第 `(i//3+1)` 小的值。

这是我写的代码。

import heapq

def find_m(A):
    """Intended: return the sum of m[i], the (i//3 + 1)-th smallest of A[0..i].

    Two heaps are used:
      max_heap -- meant to hold the i//3 + 1 smallest values, stored negated
                  (heapq is a min-heap, so negation puts the largest of them
                  at index 0);
      min_heap -- meant to hold all remaining, larger values.

    NOTE(review): this version is buggy. For the sample list below it returns
    -27 instead of the expected -38, because the first branch does not keep
    the invariant -max_heap[0] <= min_heap[0].
    """
    n = len(A)
    min_heap = []
    max_heap = []
    m = []
    for i in range(n):
        if len(max_heap) < i // 3 + 1:
            # max_heap must grow by one element.
            # BUG: A[i] is pushed unconditionally. If A[i] > min_heap[0], the
            # smallest value of min_heap should be moved to max_heap instead
            # (e.g. at i == 3, 14 is pushed although 11 sits in min_heap).
            heapq.heappush(max_heap, -A[i])
        else:
            if A[i] < -max_heap[0]:
                # A[i] is smaller than max_heap's largest value: move that
                # largest value to min_heap and keep A[i] in max_heap.
                heapq.heappush(min_heap, -heapq.heappop(max_heap))
                heapq.heappush(max_heap, -A[i])
            else:
                heapq.heappush(min_heap, A[i])
        # The root of max_heap (negated back) is taken as m[i].
        m.append(-max_heap[0])
    return sum(m)


# Sample input from the question; the expected printed result is -38.
_list = [
    11, 12, -20, 14, -10,
    -8, -7, -6, -4, -2,
]
print(find_m(_list))

预期输出:

-38

我得到的输出是:

-27

应当求和的值(即各个 `m[i]`)是:

11, 11, -20, 11, -10, -10, -8, -8, -8, -7

但我的代码实际累加的是这些值:

11, 11, -20, 14, -10, -10, -7, -7, -7, -2

我的意图是把较小的值存放在最大堆中(其大小应为 `i//3+1`),把较大的值存放在最小堆中。这样,最大堆的根应该始终就是所需的值。

很明显,问题出在处理值 14 的时候:结果中加入的是 14,而本应加入 11。

我犯了什么错误?

algorithm sorting heap
2个回答
1
投票

在任何时刻都必须保证两个堆的根之间满足正确的关系,即当两个堆都非空时,必须保持以下不变量:

-max_heap[0] <= min_heap[0]

您的代码并不总是做到这一点。例如,在第一个 `if` 块中,一个值被直接添加到最大堆中,但这可能违反上述不变量。在您的示例中添加值 14 时就是这种情况:此时最大堆的元素太少(因此需要再加入一个),但最小堆的根值为 11,所以实际上应该把 11 移到最大堆,而把 14 放入最小堆。最后的 `else` 块也需要方向相反的同类检查(问题中的代码已经通过与 `-max_heap[0]` 比较完成了这一检查;下面的更正代码用 `heappushpop` 以与 `if` 分支对称的方式实现同样的检查)。

这里是更正:

def find_m(A):
    """Return the sum over i of the (i//3 + 1)-th smallest value of A[0..i].

    The values seen so far are split between two heaps:
      * ``max_heap`` keeps the smallest i//3 + 1 of them, stored negated so
        that heapq's min-heap exposes the largest of them at index 0;
      * ``min_heap`` keeps all the remaining, larger values.
    Every step preserves ``-max_heap[0] <= min_heap[0]`` (when both heaps are
    non-empty), so the root of ``max_heap`` is always the wanted value.
    """
    min_heap = []
    max_heap = []
    m = []
    for i, value in enumerate(A):
        wanted = i // 3 + 1  # size max_heap must have after this step
        if len(max_heap) >= wanted:
            # max_heap is already big enough: whichever is larger, value or
            # max_heap's root, goes to min_heap.
            if value < -max_heap[0]:
                value = -heapq.heappushpop(max_heap, -value)
            heapq.heappush(min_heap, value)
        else:
            # max_heap needs one more element: whichever is smaller, value or
            # min_heap's root, goes to max_heap.
            if min_heap and value > min_heap[0]:
                value = heapq.heappushpop(min_heap, value)
            heapq.heappush(max_heap, -value)
        m.append(-max_heap[0])
    return sum(m)

0
投票

这是保证 trincot 所说不变量的另一种方法。我先执行 `heappushpop(max_heap, -a)`:它使 `max_heap` 仍然保存最小的那些数,而被弹出的(其中最大的)数随后推入 `min_heap`,这样新数就被并入两个堆中,且不变量保持成立。之后只需在需要时把一个数从 `min_heap` 转移到 `max_heap` 即可。

from heapq import heappush, heappop, heappushpop

def find_m(A):
    """Return the sum, over every prefix A[:i+1], of its (i//3 + 1)-th smallest value.

    After each step max_heap stores (negated) the i//3 + 1 smallest values
    seen so far and min_heap stores the rest, so -max_heap[0] is the wanted
    order statistic.
    """
    min_heap = []
    max_heap = []  # values stored negated: heapq only provides a min-heap
    total = 0
    for i, a in enumerate(A):
        # heappushpop(max_heap, -a) keeps max_heap's size and its smallest
        # values, handing back the largest of them (negated); it joins min_heap.
        overflow = -heappushpop(max_heap, -a)
        heappush(min_heap, overflow)
        if i % 3 == 0:
            # max_heap must grow by one: promote min_heap's smallest value.
            heappush(max_heap, -heappop(min_heap))
        total -= max_heap[0]  # adds the (i//3 + 1)-th smallest value
    return total


# The question's sample input; the expected printed result is -38.
_list = [
    11, 12, -20, 14, -10,
    -8, -7, -6, -4, -2,
]
print(find_m(_list))

# Inefficient but obviously correct reference solution
def naive(A):
    """Return the sum of the (i//3 + 1)-th smallest value of each prefix A[:i+1].

    Every prefix is sorted from scratch (O(n^2 log n) overall); this is only
    meant as an obviously correct reference for checking find_m.
    """
    def order_statistic(end):
        # (end - 1) // 3 is the 0-based rank wanted for the prefix A[:end].
        return sorted(A[:end])[(end - 1) // 3]

    return sum(order_statistic(end) for end in range(1, len(A) + 1))

# Random tests: for every length n in 0..99, compare find_m against the naive
# reference on 10 random lists of values drawn from range(1000).
import random

for n in range(100):
    for _ in range(10):
        A = random.choices(range(1000), k=n)
        expect, result = naive(A), find_m(A)
        assert result == expect
print('done')

在线尝试!

出于好奇做的基准测试(100,000 个随机数):

 66.58 ± 0.75 ms  trincot4
 69.71 ± 0.43 ms  trincot2
 70.01 ± 0.81 ms  trincot3
 78.47 ± 1.13 ms  Kelly2
 84.78 ± 0.90 ms  Kelly
 98.80 ± 1.66 ms  trincot

代码(在线尝试!):

# my original
def Kelly(A): 
    """Sum over i of the (i//3 + 1)-th smallest value of A[:i+1].

    max_heap holds (negated) the smallest i//3 + 1 values; min_heap the rest.
    """
    min_heap = []
    max_heap = []
    total = 0
    for i, a in enumerate(A):
        # Pass a through max_heap; the largest of (max_heap + a) moves to min_heap.
        heappush(min_heap, -heappushpop(max_heap, -a))
        if not i % 3:
            # Every third step max_heap must grow: move min_heap's smallest over.
            heappush(max_heap, -heappop(min_heap))
        # max_heap[0] is the negated order statistic, so subtracting adds it.
        total -= max_heap[0]
    return total

# optimized with local variables push/pop/pushpop
def Kelly2(A):
    """Same algorithm as Kelly, with the heapq functions bound to local names.

    Sum over i of the (i//3 + 1)-th smallest value of A[:i+1].
    """
    # Local names are looked up faster than globals inside the hot loop.
    push, pop, pushpop = heappush, heappop, heappushpop
    min_heap = []
    max_heap = []
    total = 0
    for i, a in enumerate(A):
        # Pass a through max_heap; the largest of (max_heap + a) moves to min_heap.
        push(min_heap, -pushpop(max_heap, -a))
        if not i % 3:
            # Every third step max_heap must grow: move min_heap's smallest over.
            push(max_heap, -pop(min_heap))
        total -= max_heap[0]  # max_heap stores negated values
    return total


# trincot's original
def trincot(A):
    """Sum over i of the (i//3 + 1)-th smallest value of A[:i+1].

    Keeps -max_heap[0] <= min_heap[0] at every step; max_heap holds the
    smallest i//3 + 1 values (negated), min_heap the rest.
    """
    min_heap = []
    max_heap = []
    m = []
    for i, value in enumerate(A):  # Use enumerate to get both the index and the value
        if len(max_heap) < i // 3 + 1:
            # If current value would cause a violation of the invariant, exchange it 
            #    with the minimum value in the min heap
            if min_heap and value > min_heap[0]:   
                value = heapq.heappushpop(min_heap, value)
            heapq.heappush(max_heap, -value)
        else:
            # If current value would cause a violation of the invariant, exchange it 
            #    with the maximum value in the max heap
            if value < -max_heap[0]:
                value = -heapq.heappushpop(max_heap, -value)
            heapq.heappush(min_heap, value)
        # Root of max_heap (negated back) is the (i//3 + 1)-th smallest so far.
        m.append(-max_heap[0])
    return sum(m)


# optimized to use `total`, `not i % 3`, and no `heapq.`
def trincot2(A):
    """Same algorithm as trincot, with a running total and a cheaper size test.

    Sum over i of the (i//3 + 1)-th smallest value of A[:i+1].
    """
    min_heap = []
    max_heap = []
    total = 0
    for i, value in enumerate(A):
        # max_heap must reach i//3 + 1 elements, which requires one more
        # element exactly when i is a multiple of 3.
        if not i % 3:
            # Send the smaller of value and min_heap's root to max_heap.
            if min_heap and value > min_heap[0]:   
                value = heappushpop(min_heap, value)
            heappush(max_heap, -value)
        else:
            # Send the larger of value and max_heap's root to min_heap.
            if value < -max_heap[0]:
                value = -heappushpop(max_heap, -value)
            heappush(min_heap, value)
        total -= max_heap[0]  # max_heap stores negated values
    return total


# Let heappushpop do the optimization
def trincot3(A):
    """Same as trincot2, letting heappushpop perform the comparison.

    heappushpop returns the smaller of the pushed item and the heap's root
    (or the item itself when the heap is empty), which is exactly the
    exchange trincot2 does with an explicit test.
    """
    min_heap = []
    max_heap = []
    total = 0
    for i, value in enumerate(A):
        if not i % 3:
            # max_heap grows: the smaller of value and min_heap's root goes there.
            value = heappushpop(min_heap, value)
            heappush(max_heap, -value)
        else:
            # max_heap keeps its size: the larger of value and its root leaves it.
            value = -heappushpop(max_heap, -value)
            heappush(min_heap, value)
        total -= max_heap[0]  # max_heap stores negated values
    return total


# optimized with local push/pushpop
def trincot4(A):
    """Same as trincot3 with heapq functions bound to local names and calls nested.

    Sum over i of the (i//3 + 1)-th smallest value of A[:i+1].
    """
    # Local names are looked up faster than globals inside the hot loop.
    push, pushpop = heappush, heappushpop
    min_heap = []
    max_heap = []
    total = 0
    for i, value in enumerate(A):
        if not i % 3:
            # max_heap grows: the smaller of value and min_heap's root goes there.
            push(max_heap, -pushpop(min_heap, value))
        else:
            # The larger of value and max_heap's root moves to min_heap.
            push(min_heap, -pushpop(max_heap, -value))
        total -= max_heap[0]  # max_heap stores negated values
    return total


# Implementations under benchmark; the first one's result is the reference
# that every other implementation must reproduce.
funcs = (
    Kelly,
    Kelly2,
    trincot,
    trincot2,
    trincot3,
    trincot4,
)

from time import time
from statistics import mean, stdev
from heapq import heappush, heappop, heappushpop
import heapq
import random

times = {func: [] for func in funcs}  # func -> recorded run durations (seconds)


def stats(f):
    """Return 'mean ± stdev ms' computed over the five fastest recorded runs of f."""
    fastest = sorted(times[f])[:5]
    ts = [seconds * 1e3 for seconds in fastest]
    return f'{mean(ts):6.2f} ± {stdev(ts):4.2f} ms '
# Benchmark: 25 rounds on fresh random input of 10**5 ints. Every function
# must return the same result as the first one; results print fastest first.
from time import perf_counter

for _ in range(25):
    A = random.choices(range(-10**9, 10**9), k=10**5)
    expect = None
    for f in funcs:
        # perf_counter is the clock intended for measuring short durations;
        # time() is wall-clock time and may have coarse resolution.
        t = perf_counter()
        result = f(A)
        times[f].append(perf_counter() - t)
        if expect is None:
            expect = result
        else:
            assert result == expect
# Rank numerically by the mean of the five best runs (the same runs stats()
# reports). Sorting by the formatted stats string would only be correct while
# every mean prints with the same width.
for f in sorted(funcs, key=lambda g: mean(sorted(times[g])[:5])):
    print(stats(f), f.__name__)
© www.soinside.com 2019 - 2024. All rights reserved.