我正在Pytorch中训练图像分类模型,并使用它们的default data loader加载我的训练数据。我有一个非常大的训练数据集,因此通常每堂课有两千张样本图像。我过去训练的模型总共有约20万张图像,而没有出现问题。但是我发现,当总共拥有超过一百万个图像时,Pytorch数据加载器会卡住。
我相信当我呼叫datasets.ImageFolder(...)
时代码正在挂起。当我按Ctrl-C时,这始终是输出:
Traceback (most recent call last): │
File "main.py", line 412, in <module> │
main() │
File "main.py", line 122, in main │
run_training(args.group, args.num_classes) │
File "main.py", line 203, in run_training │
train_loader = create_dataloader(traindir, tfm.train_trans, shuffle=True) │
File "main.py", line 236, in create_dataloader │
dataset = datasets.ImageFolder(directory, trans) │
File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 209, in __init__ │
is_valid_file=is_valid_file) │
File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 94, in __init__ │
samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file) │
File "/home/username/.local/lib/python3.5/site-packages/torchvision/datasets/folder.py", line 47, in make_dataset │
for root, _, fnames in sorted(os.walk(d)): │
File "/usr/lib/python3.5/os.py", line 380, in walk │
is_dir = entry.is_dir() │
Keyboard Interrupt
我以为某个地方可能会出现死锁,但是根据Ctrl-C的堆栈输出,它看起来并不像在等待锁。因此,我认为数据加载器速度很慢,因为我试图加载更多数据。我让它运行了大约2天,但没有任何进展,在加载的最后2个小时中,我检查了RAM使用量是否保持不变。我还能够在过去不到两个小时的时间内加载超过20万张图像的训练数据集。我还尝试将我的GCP机器升级为具有32个内核,4个GPU和超过100GB的RAM,但是似乎是在加载一定数量的内存之后,数据加载器被卡住了。
我很困惑在目录中循环时数据加载器如何卡住,但我仍然不确定它是否卡住或非常缓慢。有什么方法可以更改Pytortch数据加载器,使其能够处理100万以上的图像进行训练?任何调试建议也表示赞赏!
谢谢!
DataLoader
没问题,torchvision.datasets.ImageFolder
没问题,以及它如何工作(以及为什么您拥有的数据越多,其工作情况就越差)。
它挂在这行,如您的错误所示:
for root, _, fnames in sorted(os.walk(d)):
可以找到来源here。
潜在的问题是将每个path
和相应的label
保留在巨型list
中,请参见下面的代码(为简洁起见,删除了一些内容:]]
def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None): images = [] dir = os.path.expanduser(dir) # Iterate over all subfolders which were found previously for target in sorted(class_to_idx.keys()): d = os.path.join(dir, target) # Create path to this subfolder # Assuming it is directory (which usually is the case) for root, _, fnames in sorted(os.walk(d, followlinks=True)): # Iterate over ALL files in this subdirectory for fname in sorted(fnames): path = os.path.join(root, fname) # Assuming it is correctly recognized as image file item = (path, class_to_idx[target]) # Add to path with all images images.append(item) return images
显然,图像将包含一百万个字符串(也很长)和相应的类的
int
,这肯定很多,并且取决于RAM和CPU。
但是您可以创建自己的数据集(前提是您事先更改了图像的名称),所以[dataset
不会占用[[不占用内存]。]。
root
class1
class2
class3
...
使用您有/需要的班级。现在每个
class
应该具有以下数据:
class1
0.png
1.png
2.png
...
鉴于您可以继续创建数据集。创建数据集
torch.utils.data.Dataset
下面使用PIL
打开图像,但是您可以用另一种方式来做:import os
import pathlib
import torch
from PIL import Image
class ImageDataset(torch.utils.data.Dataset):
def __init__(self, root: str, folder: str, klass: int, extension: str = "png"):
self._data = pathlib.Path(root) / folder
self.klass = klass
self.extension = extension
# Only calculate once how many files are in this folder
# Could be passed as argument if you precalculate it somehow
# e.g. ls | wc -l on Linux
self._length = sum(1 for entry in os.listdir(self._data))
def __len__(self):
# No need to recalculate this value every time
return self._length
def __getitem__(self, index):
# images always follow [0, n-1], so you access them directly
return Image.open(self._data / "{}.{}".format(str(index), self.extension))
现在您可以轻松创建数据集(假定文件夹结构与上面的一样:
root = "/path/to/root/with/images" dataset = ( ImageDataset(root, "class0", 0) + ImageDataset(root, "class1", 1) + ImageDataset(root, "class2", 2) )
您可以根据需要添加任意多个具有指定类的datasets
,可以循环执行。最后,照常使用
torch.utils.data.DataLoader
,例如:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)