我正在尝试使用以下代码加载包含图像的本地数据集(总共约 225 张图像):
# Set the batch size
BATCH_SIZE = 32
# Create data loaders
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
train_dir=train_dir,
test_dir=test_dir,
transform=manual_transforms, # use manually created transforms
batch_size=BATCH_SIZE
)
# Get a batch of images
image_batch, label_batch = next(iter(train_dataloader)) # why it takes so much time? what can
I do about it?
我的问题涉及代码的最后一行和
train_dataloader
中的迭代,这需要很长的执行时间。为什么会这样呢?我只有 225 张图片。
编辑:
数据加载器的代码可以在以下链接中找到。
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import pdb
NUM_WORKERS = os.cpu_count()
def create_dataloaders(
train_dir: str,
test_dir: str,
transform: transforms.Compose,
batch_size: int,
num_workers: int=NUM_WORKERS
):
# Use ImageFolder to create dataset(s)
train_data = datasets.ImageFolder(train_dir, transform=transform)
test_data = datasets.ImageFolder(test_dir, transform=transform)
# Get class names
class_names = train_data.classes
# Turn images into data loaders
train_dataloader = DataLoader(
train_data,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=True,
)
test_dataloader = DataLoader(
test_data,
batch_size=batch_size,
shuffle=False, # don't need to shuffle test data
num_workers=num_workers,
pin_memory=True,
)
return train_dataloader, test_dataloader, class_names
next(iter(train_dataloader)
调用缓慢的主要原因是由于多处理 - 或者不正确地使用多处理。当num_workers > 0
时,对iter(train_dataloader)
的调用将分叉主Python进程(当前脚本),这意味着在调用iter(...)
之前导入期间发生的任何耗时代码,例如任何类型的文件加载这发生在全局范围内(!),将导致额外的速度减慢。也就是说,除了进程创建时间以及调用 next(iter(...))
时需要发生的数据序列化和反序列化之外,还有额外的时间。
您可以通过在调用
time.sleep(5)
之前在全局范围内的任意位置添加 next(iter(train_dataloader))
来验证这一点。然后您会看到调用将比原来慢 5 秒。
不幸的是,我不知道如何修复火炬数据加载器的这个问题,除了(1)设置
num_workers=0
,或(2)确保在导入主脚本期间没有耗时的代码,或者 (3) 不要使用 torch DataLoader,而是使用 HuggingFace 数据集接口。