I have two generators producing data, e.g.:

def xs():
    yield [1, 2]
    yield [3, 4]

def ys():
    yield [5, 6]
    yield [7, 8]
I want to process all possible (x, y) pairs:

process([1, 2], [5, 6])
process([1, 2], [7, 8])
process([3, 4], [5, 6])
process([3, 4], [7, 8])
I can do that with:

from itertools import product

for x, y in product(xs(), ys()):
    process(x, y)
The problem is that process may modify the data, for example like this:

def process(x, y):
    print(f'process({x}, {y})')
    x.pop()
    y.pop()
Then what actually happens is this:

process([1, 2], [5, 6])
process([1], [7, 8])
process([3, 4], [5])
process([3], [7])
That's because product(xs(), ys()) creates all the xs and ys only once and reuses them. So earlier process calls affect the data seen by later calls. I need to avoid that reuse.
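The reuse is easy to observe directly: product drains each input iterable into an internal pool up front and then hands out the same objects again for every pairing. A small identity check (my addition, not part of the original post):

```python
from itertools import product

def xs():
    yield [1, 2]
    yield [3, 4]

def ys():
    yield [5, 6]
    yield [7, 8]

# product() materializes both generators into pools first, then yields
# pairs of the *same* list objects, so each x object appears twice.
x_ids = [id(x) for x, y in product(xs(), ys())]
print(len(x_ids), len(set(x_ids)))  # 4 pairs, but only 2 distinct x objects
```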
This is slightly better:

for x in xs():
    for y in ys():
        process(x, y)
This still reuses each x, but every y is freshly created, resulting in:

process([1, 2], [5, 6])
process([1], [7, 8])
process([3, 4], [5, 6])
process([3], [7, 8])
One way to avoid reusing each x is to always make a deep copy:

from copy import deepcopy

for x in xs():
    for y in ys():
        process(deepcopy(x), y)
That gives the desired behavior. The problem is that deepcopy can be much slower than generating the data afresh. Here are the three approaches above with xs() and ys() each producing 100 lists of 100 ints (and process doing nothing):

  0.6 ±  0.0 ms  using_product
  3.1 ±  0.0 ms  nested_loops
404.7 ± 25.9 ms  with_deepcopy
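Part of why deepcopy is so slow here: it recurses into every element and tracks already-copied objects in a memo dict, even though these lists contain only immutable ints, for which a flat shallow copy is just as independent. A quick check (my addition):

```python
from copy import deepcopy

x = [1] * 100
a = deepcopy(x)   # recursive copy with per-element memo bookkeeping
b = list(x)       # flat C-level copy of the references

# For a list of immutable ints, both copies are equally independent:
a.pop()
b.pop()
print(len(x), len(a), len(b))  # the original list is untouched
```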
How can I always use a fresh x and a fresh y without deepcopy, so that it's much faster? It should only take about twice as long as nested_loops, since that already produces half of the values fresh.
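For what it's worth, one way to see that roughly 2x the nested_loops cost is achievable: keep one independent xs() iterator per y column, so every object handed to process is freshly generated. This sketch is my addition, not from the original post, and fresh_pairs is a hypothetical name:

```python
def fresh_pairs(xs, ys, process):
    # One independent xs() iterator per y column: column j gets its own
    # stream of freshly generated x objects, advanced in lockstep, so
    # nothing passed to process() is ever reused.
    x_iters = []
    for _ in xs():                      # drives the rows; this x is discarded
        for j, y in enumerate(ys()):    # y is fresh, as in nested_loops
            if j == len(x_iters):
                x_iters.append(xs())    # lazily start one iterator per column
            process(next(x_iters[j]), y)
```

Every x is regenerated once per y (plus the discarded driver values), so when xs() and ys() cost about the same to produce, this lands near the 2x target above.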
Benchmark/test script:
def using_product(xs, ys, process):
    for x, y in product(xs(), ys()):
        process(x, y)

def nested_loops(xs, ys, process):
    for x in xs():
        for y in ys():
            process(x, y)

def with_deepcopy(xs, ys, process):
    for x in xs():
        for y in ys():
            process(deepcopy(x), y)

funcs = [
    using_product,
    nested_loops,
    with_deepcopy,
]

from itertools import *
from copy import deepcopy
from timeit import timeit
from statistics import mean, stdev
import sys
import random

# The little example
def xs():
    yield [1, 2]
    yield [3, 4]

def ys():
    yield [5, 6]
    yield [7, 8]

def process(x, y):
    print(f'process({x}, {y})')
    x.pop()
    y.pop()

for f in funcs:
    print(f.__name__ + ':')
    f(xs, ys, process)
    print()

# Arguments for benchmark
def xs():
    for _ in range(100):
        yield [1] * 100

ys = xs

def process(x, y):
    pass

# Run the benchmark
times = {f: [] for f in funcs}

def stats(f):
    ts = [t * 1e3 for t in sorted(times[f])[:5]]
    return f'{mean(ts):5.1f} ± {stdev(ts):3.1f} ms '

for _ in range(25):
    random.shuffle(funcs)
    for f in funcs:
        t = timeit(lambda: f(xs, ys, process), number=1)
        times[f].append(t)

for f in sorted(funcs, key=stats):
    print(stats(f), f.__name__)

print('\nPython:', sys.version)
Your inputs are lists of numbers, so just use a list comprehension to copy them.
def using_product_with_comprehension(xs, ys, process):
    for x, y in product(xs(), ys()):
        process([xcoord for xcoord in x], [ycoord for ycoord in y])

def nested_loops_with_comprehension(xs, ys, process):
    for x in xs():
        for y in ys():
            process([xcoord for xcoord in x], y)
Results:
using_product:
process([1, 2], [5, 6])
process([1], [7, 8])
process([3, 4], [5])
process([3], [7])
nested_loops:
process([1, 2], [5, 6])
process([1], [7, 8])
process([3, 4], [5, 6])
process([3], [7, 8])
using_product_with_comprehension:
process([1, 2], [5, 6])
process([1, 2], [7, 8])
process([3, 4], [5, 6])
process([3, 4], [7, 8])
nested_loops_with_comprehension:
process([1, 2], [5, 6])
process([1, 2], [7, 8])
process([3, 4], [5, 6])
process([3, 4], [7, 8])
with_deepcopy:
process([1, 2], [5, 6])
process([1, 2], [7, 8])
process([3, 4], [5, 6])
process([3, 4], [7, 8])
0.7 ± 0.0 ms using_product
3.1 ± 0.1 ms nested_loops
21.4 ± 1.9 ms nested_loops_with_comprehension
36.1 ± 2.5 ms using_product_with_comprehension
372.8 ± 37.5 ms with_deepcopy
Python: 3.12.0 (main, Oct 7 2023, 10:42:35) [GCC 13.2.1 20230801]
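A closing note I'd add: list(x) (or x[:]) performs the same shallow copy as the comprehension, but as a single C-level call rather than a Python-level loop, so it is typically a bit faster. This variant is mine and was not benchmarked in the post:

```python
from itertools import product

def using_product_with_list(xs, ys, process):
    # list(x) hands process() a fresh shallow copy of each inner list,
    # so mutations in one call cannot leak into later pairs.
    for x, y in product(xs(), ys()):
        process(list(x), list(y))
```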