Pytorch DataLoader 中的 num_worker 和 prefetch_factor 无法扩展

问题描述 投票:0回答:1

似乎与 python 的多处理相关的序列化和反序列化限制了并行处理数据的好处。

在下面的示例中,我创建了一个返回 numpy 数组的自定义迭代。随着numpy数组大小的增加,数据获取过程成为瓶颈。这是预料之中的。然而,我预计增加

num_worker
prefetch_factor
将通过提前准备批次来减少这一瓶颈。但我在下面的示例中没有看到这种行为。

我测试了两种情况,其中

MyIterable
返回

  1. 小物体
    np.array((10, 150))
  2. 大物体
    np.array((1000, 150))

两种场景下处理一个批次的平均时间如下:

# small np object
avg time per batch for num workers=0: 0.47068126868714444
avg time per batch for num workers=2: 0.20982365206225495
avg time per batch for num workers=4: 0.10560789656221914
avg time per batch for num workers=6: 0.07202646931250456
avg time per batch for num workers=8: 0.05311137337469063
# large np object
avg time per batch for num workers=0: 0.6090951558124971
avg time per batch for num workers=2: 0.4594530961876444
avg time per batch for num workers=4: 0.45023533212543043
avg time per batch for num workers=6: 0.3830978863124983
avg time per batch for num workers=8: 0.3811495694375253

对于小物体,当

num_workers
增加时,每个批次的时间会按预期下降。但对于较大的物体,变化不大。我将其归因于工作进程必须序列化 np 对象,然后主进程将其反序列化。物体越大,花费的时间就越长。

但是,有了足够大的

num_worker
prefetch_factor
,数据加载器中的队列不应该总是被填满,这样数据获取就不会成为瓶颈吗?

此外,更改

prefetch_factor
不会改变任何内容。
prefetch_factor
有什么意义?该文档说主进程预加载了
num_worker * prefetch_factor
批次,但是您可以这样做,但对于减少瓶颈没有任何效果。

我在这个问题中添加了更详细的逐步分析,以供参考。

import time
import torch
import numpy as np
from time import sleep
from torch.utils.data import DataLoader, IterableDataset


def collate_fn(records):
    # some custom collation function
    return records


class MyIterable(object):
    def __init__(self, n):
        self.n = n
        self.i = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.i < self.n:
            sleep(0.003125)  # simulates data fetch time
            # return np.random.random((10, 150))  # small data item
            return np.random.random((1000, 150))  # large data item
        else:
            raise StopIteration


class MyIterableDataset(IterableDataset):
    def __init__(self, n):
        super(MyIterableDataset).__init__()
        self.n = n

    def __iter__(self):
        return MyIterable(self.n)


def get_performance_metrics(num_workers):
    ds = MyIterableDataset(n=10000)

    if num_workers == 0:
        dl = torch.utils.data.DataLoader(ds, num_workers=0, batch_size=128, collate_fn=collate_fn)
    else:
        dl = torch.utils.data.DataLoader(ds, num_workers=num_workers, prefetch_factor=4, persistent_workers=True,
                                         batch_size=128, collate_fn=collate_fn,
                                         multiprocessing_context='spawn')
    warmup = 5
    times = []
    t0 = time.perf_counter()
    for i, batch in enumerate(dl):
        sleep(0.05)  # simulates train step
        e = time.perf_counter()
        if i >= warmup:
            times.append(e - t0)
        t0 = time.perf_counter()
        if i >= 20:
            break
    print(f'avg time per batch for num workers={num_workers}: {sum(times) / len(times)}')


if __name__ == '__main__':
    num_worker_options = [0, 2, 4, 6, 8]
    for n in num_worker_options:
        get_performance_metrics(n)
pytorch python-multiprocessing pytorch-dataloader
1个回答
0
投票

无论队列有多大,填满队列都需要时间。
如果您的迭代器以足够高的频率调用 next,它将赶上向队列添加数据的处理器。
理想情况下,要感觉没有瓶颈,time_to_add_data <= time_to_process_data.
如果 time_to_add_data > time_to_process_data,num_workers 可能会有所帮助,但是
如果 time_to_add_data >> time_to_process_data,num_workers 将无关紧要,第一个工作线程将成为处理的瓶颈。

© www.soinside.com 2019 - 2024. All rights reserved.