假设我们有一个音频分类任务(AudioMNIST)。
我的管道和我见过的其他管道包括以下步骤:
我看到了方案:
要么:
或者:
整理者应该做什么,不应该做什么? (主要问题。) 正确的方案是什么?
你已经用 pytorch 标记了这个,所以我会给出 pytorch 答案。
Pytorch 数据实用程序有一个
Dataset
和一个 DataLoader
。 tl;dr, Dataset
处理加载单个示例,而 DataLoader
处理批处理和任何批量处理。
Dataset
有两种方法,__len__
用于确定数据集中的项目数量,__getitem__
用于加载单个项目。
class MyDataset(Dataset):
def __init__(self):
...
def __len__(self):
...
def __getitem__(self, index):
...
向
DataLoader
传递来自 Dataset
的输出列表(即 batch_input = [dataset.__getitem__(i) for i in idxs]
)。批量输入被发送到 collate_fn
的 DataLoader
。
def my_collate_fn(batch):
...
dataloader = DataLoader(my_dataset, batch_size, collate_fn=my_collate_fn)
考虑到在哪里做什么,
Dataset
应该处理加载单个示例。 Dataset
将被并行调用,因此受 CPU 限制的任务应该放入 Dataset
。从磁盘加载(如果适用)通常也在 Dataset
中完成。
collate_fn
负责将 Dataset
的输出列表转换为模型想要的任何格式。由于 DataLoader
处理一批数据,因此应用批处理步骤可以更有效。堆叠张量、填充长度、生成掩模或其他批量张量操作在collate_fn
中效果很好。
一般来说,将
Dataset
视为在单个示例上运行多进程,而 DataLoader
在一批示例上运行单进程。