似乎与 python 的多处理相关的序列化和反序列化限制了并行处理数据的好处。
在下面的示例中,我创建了一个返回 numpy 数组的自定义迭代。随着numpy数组大小的增加,数据获取过程成为瓶颈。这是预料之中的。然而,我预计增加
num_worker
和 prefetch_factor
将通过提前准备批次来减少这一瓶颈。但我在下面的示例中没有看到这种行为。
我测试了两种情况,其中
MyIterable
返回
np.array((10, 150))
np.array((1000, 150))
两种场景下处理一个批次的平均时间如下:
# small np object
avg time per batch for num workers=0: 0.47068126868714444
avg time per batch for num workers=2: 0.20982365206225495
avg time per batch for num workers=4: 0.10560789656221914
avg time per batch for num workers=6: 0.07202646931250456
avg time per batch for num workers=8: 0.05311137337469063
# large np object
avg time per batch for num workers=0: 0.6090951558124971
avg time per batch for num workers=2: 0.4594530961876444
avg time per batch for num workers=4: 0.45023533212543043
avg time per batch for num workers=6: 0.3830978863124983
avg time per batch for num workers=8: 0.3811495694375253
对于小物体,当
num_workers
增加时,每个批次的时间会按预期下降。但对于较大的物体,变化不大。我将其归因于工作进程必须序列化 np 对象,然后主进程将其反序列化。物体越大,花费的时间就越长。
但是,有了足够大的
num_worker
和prefetch_factor
,数据加载器中的队列不应该总是被填满,这样数据获取就不会成为瓶颈吗?
此外,更改
prefetch_factor
不会改变任何内容。 prefetch_factor
有什么意义?该文档说主进程预加载了 num_worker * prefetch_factor
批次,但是您可以这样做,但对于减少瓶颈没有任何效果。
我在这个问题中添加了更详细的逐步分析,以供参考。
import time
import torch
import numpy as np
from time import sleep
from torch.utils.data import DataLoader, IterableDataset
def collate_fn(records):
# some custom collation function
return records
class MyIterable(object):
def __init__(self, n):
self.n = n
self.i = 0
def __iter__(self):
return self
def __next__(self):
if self.i < self.n:
sleep(0.003125) # simulates data fetch time
# return np.random.random((10, 150)) # small data item
return np.random.random((1000, 150)) # large data item
else:
raise StopIteration
class MyIterableDataset(IterableDataset):
def __init__(self, n):
super(MyIterableDataset).__init__()
self.n = n
def __iter__(self):
return MyIterable(self.n)
def get_performance_metrics(num_workers):
ds = MyIterableDataset(n=10000)
if num_workers == 0:
dl = torch.utils.data.DataLoader(ds, num_workers=0, batch_size=128, collate_fn=collate_fn)
else:
dl = torch.utils.data.DataLoader(ds, num_workers=num_workers, prefetch_factor=4, persistent_workers=True,
batch_size=128, collate_fn=collate_fn,
multiprocessing_context='spawn')
warmup = 5
times = []
t0 = time.perf_counter()
for i, batch in enumerate(dl):
sleep(0.05) # simulates train step
e = time.perf_counter()
if i >= warmup:
times.append(e - t0)
t0 = time.perf_counter()
if i >= 20:
break
print(f'avg time per batch for num workers={num_workers}: {sum(times) / len(times)}')
if __name__ == '__main__':
num_worker_options = [0, 2, 4, 6, 8]
for n in num_worker_options:
get_performance_metrics(n)
无论队列有多大,填满队列都需要时间。
如果您的迭代器以足够高的频率调用 next,它将赶上向队列添加数据的处理器。
理想情况下,要感觉没有瓶颈,time_to_add_data <= time_to_process_data.
如果 time_to_add_data > time_to_process_data,num_workers 可能会有所帮助,但是
如果 time_to_add_data >> time_to_process_data,num_workers 将无关紧要,第一个工作线程将成为处理的瓶颈。