我正在编写一个 Python 程序，使用最大堆(maxHeap)和最小堆(minHeap)计算所有 `m[i]` 的总和，其中 `m[i]` 是 `A[0], A[1], ..., A[i]` 这些值中第 `(i//3+1)` 小的值。
这是我写的代码：
import heapq
def find_m(A):
    """Sum, over i, of the (i//3 + 1)-th smallest value among A[0..i].

    NOTE(review): this is the asker's buggy version. On the example input
    [11, 12, -20, 14, -10, -8, -7, -6, -4, -2] it returns -27 instead of the
    expected -38, because the heap invariant ``-max_heap[0] <= min_heap[0]``
    is not enforced in the first ``if`` branch (see the comment there).
    """
    n = len(A)
    min_heap = []  # intended to hold the larger values
    max_heap = []  # values stored negated; intended to hold the i//3 + 1 smallest
    m = []  # m[i] is the value read from max_heap's root at step i
    for i in range(n):
        if len(max_heap) < i // 3 + 1:
            # max_heap must grow by one here. BUG: A[i] is pushed without
            # comparing it to min_heap[0]; if A[i] is larger (e.g. 14 while
            # min_heap's root is 11), max_heap's root becomes wrong.
            heapq.heappush(max_heap, -A[i])
        else:
            if A[i] < -max_heap[0]:
                # A[i] is smaller than max_heap's largest value: move that
                # value to min_heap and put A[i] into max_heap instead.
                heapq.heappush(min_heap, -heapq.heappop(max_heap))
                heapq.heappush(max_heap, -A[i])
            else:
                heapq.heappush(min_heap, A[i])
        # -max_heap[0] is the largest value currently in max_heap.
        m.append(-max_heap[0])
    return sum(m)
# The question's example input (the expected result is -38).
_list = [
    11, 12, -20, 14, -10,
    -8, -7, -6, -4, -2,
]
print(find_m(_list))
预期输出：
-38
实际得到的输出：
-27
应当被累加的值(即 `m` 的各项)是：
11, 11, -20, 11, -10, -10, -8, -8, -8, -7
但我发现我的代码实际累加的是：
11, 11, -20, 14, -10, -10, -7, -7, -7, -2
我的意图是：把较小的值存放在最大堆中，并使其大小保持为 `i//3+1`；把较大的值存放在最小堆中。这样，最大堆的根应当始终就是所需的值。
显然，在处理值 14 的时候出了问题：被加入结果的是 14，而本应加入的是 11。
我犯了什么错误？
在任何时刻都必须保证两个堆的根之间保持正确的关系。也就是说，当两个堆都非空时，必须满足以下不变量：
`-max_heap[0] <= min_heap[0]`
你的代码并不总是保证这一点。例如，在第一个 `if` 块中，一个值被直接加入最大堆，这可能会破坏上述不变量。你的示例在加入值 14 时正是这种情况：此时最大堆的元素太少(因此需要再加入一个)，但最小堆的根是 11，所以真正应该移入最大堆的是这个 11(而 14 应进入最小堆)。最后的 `else` 块也可能出现类似的问题(方向相反)。
以下是修正后的代码：
def find_m(A):
    """Sum, over every prefix A[:i+1], its (i//3 + 1)-th smallest value.

    Two heaps partition the values seen so far: ``low`` (a max-heap stored
    as negated values) holds the i//3 + 1 smallest ones, and ``high`` (a
    min-heap) holds the rest. The invariant ``-low[0] <= high[0]`` is kept
    at every step, so ``-low[0]`` is always the wanted order statistic.
    """
    high = []  # min-heap of the larger values
    low = []   # max-heap (negated) of the smaller values
    picked = []
    for idx, item in enumerate(A):
        needs_growth = len(low) < idx // 3 + 1
        if needs_growth:
            # Feed ``low`` with the smaller of ``item`` and high's minimum,
            # so nothing in ``low`` can exceed anything in ``high``.
            if high and item > high[0]:
                item = heapq.heappushpop(high, item)
            heapq.heappush(low, -item)
        else:
            # Feed ``high`` with the larger of ``item`` and low's maximum.
            if item < -low[0]:
                item = -heapq.heappushpop(low, -item)
            heapq.heappush(high, item)
        picked.append(-low[0])
    return sum(picked)
下面是另一种维持 trincot 所提到的不变量的方法。我的 `heappushpop(max_heap, -a)` 会先把新值放进 `max_heap`，再弹出其中最大的一个，从而保证 `max_heap` 中保留的仍然是最小的那些数。因此，把弹出的值推入 `min_heap` 之后，新数字就已并入两个堆，且不变量仍然成立。然后只需在需要时(即 `i` 是 3 的倍数时)把一个数从 `min_heap` 转移到 `max_heap` 即可。
from heapq import heappush, heappop, heappushpop
def find_m(A):
    """Sum the (i//3 + 1)-th smallest value of each prefix A[:i+1].

    Every new value first passes through ``lower`` (a max-heap of negated
    values) via heappushpop, which hands back the largest of the small
    values; that value is parked in ``upper`` (a min-heap). Whenever i is
    a multiple of 3 the small side must grow by one, so upper's smallest
    value moves back. After each step ``-lower[0]`` is the wanted value.
    """
    upper = []
    lower = []
    total = 0
    for i, a in enumerate(A):
        largest_small = -heappushpop(lower, -a)
        heappush(upper, largest_small)
        if i % 3 == 0:
            heappush(lower, -heappop(upper))
        total -= lower[0]
    return total
# Your test: the question's example input (expected output: -38).
_list = [
    11, 12, -20, 14, -10,
    -8, -7, -6, -4, -2,
]
print(find_m(_list))
# Inefficient but obviously correct reference solution
def naive(A):
    """Brute force: sort every prefix and take its element at index i//3."""
    total = 0
    for stop in range(1, len(A) + 1):
        prefix = sorted(A[:stop])
        total += prefix[(stop - 1) // 3]
    return total
# Random tests: compare find_m against the brute-force naive() on many inputs.
import random
for n in range(100):
    for _ in range(10):
        # Mix negative and non-negative values, like the question's example.
        A = random.choices(range(-1000, 1000), k=n)
        expect = naive(A)
        result = find_m(A)
        # Report the failing input so a mismatch can be reproduced.
        assert result == expect, (A, expect, result)
print('done')
出于好奇，我用 100,000 个随机数做了基准测试(Benchmarks)：
66.58 ± 0.75 ms trincot4
69.71 ± 0.43 ms trincot2
70.01 ± 0.81 ms trincot3
78.47 ± 1.13 ms Kelly2
84.78 ± 0.90 ms Kelly
98.80 ± 1.66 ms trincot
代码(在线尝试!):
# my original
def Kelly(A):
    """Sum of the (i//3 + 1)-th smallest value of each prefix of A.

    Each value is routed through the max-heap of small values (stored
    negated), whose largest element then moves to the min-heap of large
    values. On every third step (i % 3 == 0) the min-heap's smallest value
    moves back, so the small side grows by one.
    """
    large = []      # min-heap
    small_neg = []  # max-heap via negation
    total = 0
    for i, a in enumerate(A):
        spill = heappushpop(small_neg, -a)  # -(largest small value)
        heappush(large, -spill)
        if i % 3 == 0:
            heappush(small_neg, -heappop(large))
        total -= small_neg[0]
    return total
# optimized with local variables push/pop/pushpop
def Kelly2(A):
    """Two-heap prefix order-statistic sum with heapq functions as locals.

    Binding heappush/heappop/heappushpop to local names turns the global
    lookups in the hot loop into local lookups.
    """
    push = heappush
    pop = heappop
    pushpop = heappushpop
    large, small_neg = [], []
    total = 0
    for i, a in enumerate(A):
        push(large, -pushpop(small_neg, -a))
        if i % 3 == 0:
            push(small_neg, -pop(large))
        total -= small_neg[0]
    return total
# trincot's original
def trincot(A):
    """Two-heap running order statistic, summed over all prefixes of A.

    ``lower`` is a max-heap (values negated) of the i//3 + 1 smallest values
    seen so far and ``upper`` is a min-heap of the rest. A value entering one
    heap is first swapped with the other heap's root if it would violate
    ``-lower[0] <= upper[0]``; ``-lower[0]`` is then the wanted value.
    """
    upper = []
    lower = []
    picks = []
    for idx, val in enumerate(A):
        if len(lower) >= idx // 3 + 1:
            # lower is full: val goes to upper, unless it is smaller than
            # lower's maximum, in which case the two trade places.
            if val < -lower[0]:
                val = -heapq.heappushpop(lower, -val)
            heapq.heappush(upper, val)
        else:
            # lower must grow: take val, or upper's minimum if that is smaller.
            if upper and val > upper[0]:
                val = heapq.heappushpop(upper, val)
            heapq.heappush(lower, -val)
        picks.append(-lower[0])
    return sum(picks)
# optimized to use `total`, `not i % 3`, and no `heapq.`
def trincot2(A):
    """Invariant-preserving two-heap sum of the (i//3 + 1)-th smallest values.

    On steps where i is a multiple of 3 the max-heap of small values (stored
    negated) must grow; otherwise the new value goes to the min-heap. In
    both cases the value is first swapped with the other heap's root if it
    would break ``-lower[0] <= upper[0]``.
    """
    upper = []
    lower = []
    total = 0
    for i, val in enumerate(A):
        if not i % 3:
            # lower grows: feed it val or upper's minimum, whichever is smaller.
            if upper and upper[0] < val:
                val = heappushpop(upper, val)
            heappush(lower, -val)
        else:
            # upper grows: feed it val or lower's maximum, whichever is larger.
            if -lower[0] > val:
                val = -heappushpop(lower, -val)
            heappush(upper, val)
        total -= lower[0]
    return total
# Let heappushpop do the optimization
def trincot3(A):
    """Two-heap prefix order-statistic sum without explicit comparisons.

    heappushpop(heap, x) returns the smaller of x and heap[0] (or x itself
    when the heap is empty), so it already performs the conditional swap
    that keeps ``-lower[0] <= upper[0]``.
    """
    upper = []
    lower = []
    total = 0
    for i, val in enumerate(A):
        if i % 3 == 0:
            heappush(lower, -heappushpop(upper, val))
        else:
            heappush(upper, -heappushpop(lower, -val))
        total -= lower[0]
    return total
# optimized with local push/pushpop
def trincot4(A):
    """Two-heap prefix order-statistic sum with heapq functions as locals.

    heappushpop performs the conditional swap between the heaps; binding it
    and heappush to local names avoids global lookups in the loop.
    """
    push = heappush
    pushpop = heappushpop
    upper, lower = [], []
    total = 0
    for i, val in enumerate(A):
        if i % 3:
            push(upper, -pushpop(lower, -val))
        else:
            push(lower, -pushpop(upper, val))
        total -= lower[0]
    return total
funcs = Kelly, Kelly2, trincot, trincot2, trincot3, trincot4
from time import perf_counter, time  # `time` kept; timing uses perf_counter
from statistics import mean, stdev
from heapq import heappush, heappop, heappushpop
import heapq
import random

# Run times (seconds) per function, one entry per round.
times = {f: [] for f in funcs}


def _best_ms(f):
    """Return the 5 fastest recorded run times of *f*, in milliseconds."""
    return [t * 1e3 for t in sorted(times[f])[:5]]


def stats(f):
    """Format mean ± stdev of *f*'s 5 fastest runs, e.g. ' 66.58 ± 0.75 ms '."""
    ts = _best_ms(f)
    return f'{mean(ts):6.2f} ± {stdev(ts):4.2f} ms '


for _ in range(25):
    A = random.choices(range(-10**9, 10**9), k=10**5)
    expect = None
    for f in funcs:
        # perf_counter() is monotonic with the highest available resolution;
        # time() is wall-clock and can jump, so it is a poor stopwatch.
        t = perf_counter()
        result = f(A)
        times[f].append(perf_counter() - t)
        if expect is None:
            expect = result
        else:
            # Name the implementation that disagrees with the first one.
            assert result == expect, f.__name__
# Sort by the numeric mean: sorting by the formatted string from stats()
# misorders results once a mean reaches 1000 ms ('1000.00' < '999.00').
for f in sorted(funcs, key=lambda f: mean(_best_ms(f))):
    print(stats(f), f.__name__)