Pytorch默认数据加载器卡在大型图像分类训练集上

问题描述 投票:1回答:1

我正在Pytorch中训练图像分类模型,并使用它们的default data loader加载我的训练数据。我有一个非常大的训练数据集,因此通常每堂课有两千张样本图像。我过去训练的模型总共有约20万张图像,而没有出现问题。但是我发现,当总共拥有超过一百万个图像时,Pytorch数据加载器会卡住。

我相信当我呼叫datasets.ImageFolder(...)时代码正在挂起。当我按Ctrl-C时,这始终是输出:

Traceback (most recent call last):                                                                                                 │
  File "main.py", line 412, in <module>                                                                                            │
    main()                                                                                                                         │
  File "main.py", line 122, in main                                                                                                │
    run_training(args.group, args.num_classes)                                                                                     │
  File "main.py", line 203, in run_training                                                                                        │
    train_loader = create_dataloader(traindir, tfm.train_trans, shuffle=True)                                                      │
  File "main.py", line 236, in create_dataloader                                                                                   │
    dataset = datasets.ImageFolder(directory, trans)                                                                               │
  File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 209, in __init__     │
    is_valid_file=is_valid_file)                                                                                                   │
  File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 94, in __init__      │
    samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)                                                     │
  File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 47, in make_dataset  │
    for root, _, fnames in sorted(os.walk(d)):                                                                                     │
  File "/usr/lib/python3.5/os.py", line 380, in walk                                                                               │
    is_dir = entry.is_dir()                                                                                                        │
Keyboard Interrupt                                                                                                                       

我以为某个地方可能会出现死锁,但是根据Ctrl-C的堆栈输出,它看起来并不像在等待锁。因此,我认为数据加载器速度很慢,因为我试图加载更多数据。我让它运行了大约2天,但没有任何进展,在加载的最后2个小时中,我检查了RAM使用量是否保持不变。我还能够在过去不到两个小时的时间内加载超过20万张图像的训练数据集。我还尝试将我的GCP机器升级为具有32个内核,4个GPU和超过100GB的RAM,但是似乎是在加载一定数量的内存之后,数据加载器被卡住了。

我很困惑在目录中循环时数据加载器如何卡住,但我仍然不确定它是否卡住或非常缓慢。有什么方法可以更改Pytortch数据加载器,使其能够处理100万以上的图像进行训练?任何调试建议也表示赞赏!

谢谢!

deep-learning computer-vision classification pytorch dataloader
1个回答
0
投票

DataLoader没问题,torchvision.datasets.ImageFolder没问题,以及它如何工作(以及为什么您拥有的数据越多,其工作情况就越差)。

它挂在这行,如您的错误所示:

for root, _, fnames in sorted(os.walk(d)): 

可以找到来源here

潜在的问题是将每个path和相应的label保留在巨型list中,请参见下面的代码(为简洁起见,删除了一些内容:]]

def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None):
    images = []
    dir = os.path.expanduser(dir)
    # Iterate over all subfolders which were found previously
    for target in sorted(class_to_idx.keys()):
        d = os.path.join(dir, target) # Create path to this subfolder
        # Assuming it is directory (which usually is the case)
        for root, _, fnames in sorted(os.walk(d, followlinks=True)):
            # Iterate over ALL files in this subdirectory
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                # Assuming it is correctly recognized as image file
                item = (path, class_to_idx[target])
                # Add to path with all images
                images.append(item)

    return images

显然,图像将包含一百万个字符串(也很长)和相应的类的int,这肯定很多,并且取决于RAM和CPU。

但是您可以创建自己的数据集(前提是您事先更改了图像的名称),所以[dataset不会占用[[不占用内存]。]。

设置数据结构

您的文件夹结构应如下所示:

root class1 class2 class3 ...

使用您有/需要的班级。

现在每个class应该具有以下数据:

class1 0.png 1.png 2.png ...

鉴于您可以继续创建数据集。

创建数据集

torch.utils.data.Dataset下面使用PIL打开图像,但是您可以用另一种方式来做:

import os import pathlib import torch from PIL import Image class ImageDataset(torch.utils.data.Dataset): def __init__(self, root: str, folder: str, klass: int, extension: str = "png"): self._data = pathlib.Path(root) / folder self.klass = klass self.extension = extension # Only calculate once how many files are in this folder # Could be passed as argument if you precalculate it somehow # e.g. ls | wc -l on Linux self._length = sum(1 for entry in os.listdir(self._data)) def __len__(self): # No need to recalculate this value every time return self._length def __getitem__(self, index): # images always follow [0, n-1], so you access them directly return Image.open(self._data / "{}.{}".format(str(index), self.extension))

现在您可以轻松创建数据集(假定文件夹结构与上面的一样:

root = "/path/to/root/with/images" dataset = ( ImageDataset(root, "class0", 0) + ImageDataset(root, "class1", 1) + ImageDataset(root, "class2", 2) )

您可以根据需要添加任意多个具有指定类的datasets,可以循环执行。

最后,照常使用torch.utils.data.DataLoader,例如:

dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

© www.soinside.com 2019 - 2024. All rights reserved.