我正在尝试使用自定义排序谓词构建一个堆。由于进入其中的值是“用户定义”类型,因此我无法修改它们的内置比较谓词。
有没有办法做这样的事情:
h = heapq.heapify([...], key=my_lt_pred)
h = heapq.heappush(h, key=my_lt_pred)
或者更好的是,我可以将
heapq
函数包装在我自己的容器中,这样我就不需要继续传递谓词。
根据heapq文档,自定义堆顺序的方法是让堆上的每个元素都是一个元组,第一个元组元素是接受正常Python比较的元素。
heapq 模块中的函数有点麻烦(因为它们不是面向对象的),并且总是要求我们的堆对象(堆化列表)作为第一个参数显式传递。我们可以通过创建一个非常简单的包装类来一石二鸟,该包装类将允许我们指定一个
key
函数,并将堆呈现为一个对象。
下面的类保留一个内部列表,其中每个元素都是一个元组,其中第一个成员是一个键,在元素插入时使用
key
参数计算,在堆实例化时传递:
# -*- coding: utf-8 -*-
import heapq
class MyHeap(object):
def __init__(self, initial=None, key=lambda x:x):
self.key = key
self.index = 0
if initial:
self._data = [(key(item), i, item) for i, item in enumerate(initial)]
self.index = len(self._data)
heapq.heapify(self._data)
else:
self._data = []
def push(self, item):
heapq.heappush(self._data, (self.key(item), self.index, item))
self.index += 1
def pop(self):
return heapq.heappop(self._data)[2]
(额外的
self.index
部分是为了避免当评估的键值是平局并且存储的值不可直接比较时发生冲突 - 否则 heapq 可能会因 TypeError 而失败)
定义一个类,在其中重写
__lt__()
函数。请参阅下面的示例(适用于 Python 3.7):
import heapq
class Node(object):
def __init__(self, val: int):
self.val = val
def __repr__(self):
return f'Node value: {self.val}'
def __lt__(self, other):
return self.val < other.val
heap = [Node(2), Node(0), Node(1), Node(4), Node(2)]
heapq.heapify(heap)
print(heap) # output: [Node value: 0, Node value: 2, Node value: 1, Node value: 4, Node value: 2]
heapq.heappop(heap)
print(heap) # output: [Node value: 1, Node value: 2, Node value: 2, Node value: 4]
heapq 文档建议堆元素可以是元组,其中第一个元素是优先级并定义排序顺序。
与您的问题更相关的是,该文档包含“示例代码的讨论”,说明如何实现自己的 heapq 包装函数来处理排序稳定性和具有同等优先级的元素问题(以及其他问题)。 简而言之,他们的解决方案是让 heapq 中的每个元素都是具有优先级、条目计数和要插入的元素的三元组。条目计数确保具有相同优先级的元素按照添加到堆的顺序进行排序。
用它来比较 heapq 中对象的值
假设您有以下(姓名、年龄)列表
a = [('Tim',4), ('Radha',9), ('Rob',7), ('Krsna',3)]
并且您希望通过将它们添加到最小堆中来根据其年龄对该列表进行排序,而不是编写所有自定义比较器内容,您可以在将元组推送到队列之前翻转元组内容的顺序。这是因为 heapq.heappush() 默认按元组的第一个元素排序。 像这样:
import heapq
heap = []
heapq.heapify(heap)
for element in a:
heapq.heappush(heap, (element[1],element[0]))
如果这适合您的工作并且您不想编写混乱的自定义比较器,那么这是一个简单的技巧。
同样,它默认按升序对值进行排序。如果要按年龄降序排序,请翻转内容并使元组第一个元素的值为负数:
import heapq
heap = []
heapq.heapify(heap)
for element in a:
heapq.heappush(heap, (-element[1],element[0]))
cmp_to_key
模块中的
functools
。 cpython源代码. 假设您需要一个三元组的优先级队列,并使用最后一个属性指定优先级。
from heapq import *
from functools import cmp_to_key
def mycmp(triplet_left, triplet_right):
key_l, key_r = triplet_left[2], triplet_right[2]
if key_l > key_r:
return -1 # larger first
elif key_l == key_r:
return 0 # equal
else:
return 1
WrapperCls = cmp_to_key(mycmp)
pq = []
myobj = tuple(1, 2, "anystring")
# to push an object myobj into pq
heappush(pq, WrapperCls(myobj))
# to get the heap top use the `obj` attribute
inner = pq[0].obj
性能测试:
代码
from functools import cmp_to_key
from timeit import default_timer as time
from random import randint
from heapq import *
class WrapperCls1:
__slots__ = 'obj'
def __init__(self, obj):
self.obj = obj
def __lt__(self, other):
kl, kr = self.obj[2], other.obj[2]
return True if kl > kr else False
def cmp_class2(obj1, obj2):
kl, kr = obj1[2], obj2[2]
return -1 if kl > kr else 0 if kl == kr else 1
WrapperCls2 = cmp_to_key(cmp_class2)
triplets = [[randint(-1000000, 1000000) for _ in range(3)] for _ in range(100000)]
# tuple_triplets = [tuple(randint(-1000000, 1000000) for _ in range(3)) for _ in range(100000)]
def test_cls1():
pq = []
for triplet in triplets:
heappush(pq, WrapperCls1(triplet))
def test_cls2():
pq = []
for triplet in triplets:
heappush(pq, WrapperCls2(triplet))
def test_cls3():
pq = []
for triplet in triplets:
heappush(pq, (-triplet[2], triplet))
start = time()
for _ in range(10):
test_cls1()
# test_cls2()
# test_cls3()
print("total running time (seconds): ", -start+(start:=time()))
结果
list
代替
tuple
,每个功能:
WrapperCls1:16.2ms__slots__
__lt__()
函数和
__slots__
属性的自定义类稍快一些。的答案,我通过扩展元组创建了一个最大优先级队列:
import heapq
class MaxTuple(tuple):
def __lt__(self, other):
return self[0] > other[0]
my_tuples = [(2, "orange"), (1, "red"), (5, "blue"), (3, "yellow"), (4, "green")]
my_queue = [MaxTuple(t) for t in my_tuples]
heapq.heapify(my_queue)
while my_queue:
print(heapq.heappop(my_queue))
将堆从最大弹出到最小:
(5, 'blue')
(4, 'green')
(3, 'yellow')
(2, 'orange')
(1, 'red')