我的问题很简单,但我觉得很难说清楚,所以请允许我一步步解释。
假设我有
N
项和 N
对应的索引。
每个项目都可以使用相应的索引加载。
def load_item(index: int) -> ItemType:
# Mostly just reading, but very slow.
return item
我还有一个函数,它需要两个(加载的)项目并计算分数。
def calc_score(item_a: ItemType, item_b: ItemType) -> ScoreType:
# Much faster than load function.
return score
请注意
calc_score(a, b) == calc_score(b, a)
。
我想要做的是计算所有 2 项组合的分数,并找到(至少)一个给出最高分数的组合。
这可以按如下方式实现:
def dumb_solution(n: int) -> Tuple[int, int]:
best_score = 0
best_combination = None
for index_a, index_b in itertools.combinations(range(n), 2):
item_a = load_item(index_a)
item_b = load_item(index_b)
score = calc_score(item_a, item_b)
if score > best_score:
best_score = score
best_combination = (index_a, index_b)
return best_combination
但是,这个解决方案调用了
load_item
函数N*(N+1)
次,这是这个函数的瓶颈。
这可以通过使用缓存来解决。 然而不幸的是,这些项目太大了,不可能将所有项目都保存在内存中。 因此,我们需要使用大小受限的缓存。
from functools import lru_cache
@lru_cache(maxsize=M)
def load(index: int) -> ItemType:
# Very slow process.
return item
请注意,
M
(缓存大小)远小于 N
(大约 N // 10
到 N // 2
)。
问题是典型的组合顺序对于 LRU 缓存来说并不理想。
例如,当
N=6, M=3
时,itertools.combinations
生成以下序列,load_item
函数的调用次数为17次。
[
(0, 1), # 1, 2
(0, 2), # -, 3
(0, 3), # -, 4
(0, 4), # -, 5
(0, 5), # -, 6
(1, 2), # 7, 8
(1, 3), # -, 9
(1, 4), # -, 10
(1, 5), # -, 11
(2, 3), # 12, 13
(2, 4), # -, 14
(2, 5), # -, 15
(3, 4), # 16, 17
(3, 5), # -, -
(4, 5), # -, -
]
但是,如果我将上面的顺序重新排列如下,则调用次数将是 10。
[
(0, 1), # 1, 2
(0, 2), # -, 3
(1, 2), # -, -
(0, 3), # -, 4
(2, 3), # -, -
(0, 4), # -, 5
(3, 4), # -, -
(0, 5), # -, 6
(4, 5), # -, -
(1, 4), # 7, -
(1, 5), # -, -
(1, 3), # -, 8
(3, 5), # -, -
(2, 5), # 9, -
(2, 4), # -, 10
]
如何生成一系列 2 项组合以最大化缓存命中率?
我提出的解决方案是优先考虑缓存中已有的项目。
from collections import OrderedDict
def prioritizes_item_already_in_cache(n, cache_size):
items = list(itertools.combinations(range(n), 2))
cache = OrderedDict()
reordered = []
def update_cache(x, y):
cache[x] = cache[y] = None
cache.move_to_end(x)
cache.move_to_end(y)
while len(cache) > cache_size:
cache.popitem(last=False)
while items:
# Find a pair where both are cached.
for i, (a, b) in enumerate(items):
if a in cache and b in cache:
reordered.append((a, b))
update_cache(a, b)
del items[i]
break
else:
# Find a pair where one of them is cached.
for i, (a, b) in enumerate(items):
if a in cache or b in cache:
reordered.append((a, b))
update_cache(a, b)
del items[i]
break
else:
# Cannot find item in cache.
a, b = items.pop(0)
reordered.append((a, b))
update_cache(a, b)
return reordered
对于
N=100, M=10
,此序列导致 1660 个调用,大约是典型序列的 1/3。对于 N=100, M=50
,只有 155 个呼叫。所以我想我可以说这是一个有前途的方法。
不幸的是,这个功能对于大
N
来说太慢而且没用。
我没能完成N=1000
,但实际数据有几万。
此外,它没有考虑在没有找到缓存项时如何选择项。
因此,即使它很快,它在理论上是否是最佳解决方案也是值得怀疑的(所以请注意我的问题不是如何使上述函数更快)。
这是要测试的代码。
import functools
import itertools
import time
from collections import OrderedDict
from typing import Callable, Iterable, Tuple
ItemType = int
ScoreType = int
def load_item(index: int) -> ItemType:
return int(index) # dummy for test
def calc_score(item_a: ItemType, item_b: ItemType) -> ScoreType:
return abs(item_a - item_b) # dummy for test
class LRUCacheWithCounter:
def __init__(self, maxsize: int):
def wrapped_func(key):
self.load_count += 1
# print(key, self.load_count)
return load_item(key)
self.__cache = functools.lru_cache(maxsize=maxsize)(wrapped_func)
self.load_count = 0
def __call__(self, key: int) -> int:
return self.__cache(key)
def basic_loop(iterator: Iterable[Tuple[int, int]], cached_load: Callable[[int], int]):
best_score = 0
best_combination = None
for i, j in iterator:
a = cached_load(i)
b = cached_load(j)
score = calc_score(a, b)
if score > best_score:
best_score = score
best_combination = (i, j)
return best_score, best_combination
def baseline(n: int, cache_size: int) -> Iterable[Tuple[int, int]]:
return list(itertools.combinations(range(n), 2))
def prioritizes_item_already_in_cache(n: int, cache_size: int) -> Iterable[Tuple[int, int]]:
items = list(itertools.combinations(range(n), 2))
cache = OrderedDict()
reordered = []
def update_cache(x, y):
cache[x] = cache[y] = None
cache.move_to_end(x)
cache.move_to_end(y)
while len(cache) > cache_size:
cache.popitem(last=False)
while items:
# Find a pair where both are cached.
for i, (a, b) in enumerate(items):
if a in cache and b in cache:
reordered.append((a, b))
update_cache(a, b)
del items[i]
break
else:
# Find a pair where one of them is cached.
for i, (a, b) in enumerate(items):
if a in cache or b in cache:
reordered.append((a, b))
update_cache(a, b)
del items[i]
break
else:
# Cannot find item in cache.
a, b = items.pop(0)
reordered.append((a, b))
update_cache(a, b)
return reordered
# def your_solution_here(n: int, cache_size: int) -> Iterable[Tuple[int, int]]:
# pass
def benchmark():
n = 100 # N
cache_size = 30 # M
def run(func):
started = time.perf_counter()
reordered = func(n, cache_size)
elapsed = time.perf_counter() - started
cache = LRUCacheWithCounter(cache_size)
score, comb = basic_loop(iterator=reordered, cached_load=cache)
print(f"{func.__name__}: {cache.load_count=}, {elapsed=}, {score=}, {comb=}")
run(baseline)
run(prioritizes_item_already_in_cache)
# run(your_solution_here)
if __name__ == "__main__":
benchmark()
我对您不使用上述测试代码或更改
basic_loop
或 LRUCacheWithCounter
的行为没有问题。
补充说明:
感谢您读完这篇长文。
这是一个简单的递归定义的排序,它不依赖于缓存大小,并且在基准测试中获得 566 次负载:
def cache_oblivious(n: int, cache_size: int) -> Iterable[Tuple[int, int]]:
dest = []
def findPairs(lo1: int, n1: int, lo2: int, n2: int):
if n1 < 1 or n2 < 1:
return
if n1 == 1:
for i in range(max(lo1+1,lo2), lo2+n2):
dest.append((lo1, i))
elif n2 == 1:
for i in range(lo1, min(lo1+n1, lo2)):
dest.append((i, lo2))
elif n1 >= n2:
half = n1//2
findPairs(lo1, half, lo2, n2)
findPairs(lo1+half, n1-half, lo2, n2)
else:
half = n2//2
findPairs(lo1, n1, lo2, half)
findPairs(lo1, n1, lo2+half, n2-half)
findPairs(0,n,0,n)
return dest