我正在使用数据集和数据加载器、整理函数和采样器将其他人的代码转换为更整洁的 torch-y 管道。虽然我以前做过这样的工作,但我不知道如何解决以下问题。
数据集包含句子作为样本。因此,每个样本都有许多单词(或
tokens
),我们可以通过天真地在空白(sample.split()
)上分割样本来获得这些单词。这样的虚拟数据集可以如下所示:
from random import randint
from torch.utils.data import Dataset
class DummyDataset(Dataset):
def __init__(self):
data = []
for _ in range(128):
data.append("hello " * randint(64, 176))
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx: int):
return self.data[idx]
现在我希望能够加载数据,以便最大。批次中的tokens数量不超过250。这意味着批次大小在迭代之间可能不同。一个批次可能包含两个样本,总令牌数不超过 250 个(例如 127 + 77),另一个批次可以包含三个样本(66+66+66)。现在,其核心功能相当简单。完整示例如下;没有通过长度排序或其他方式进行优化,但对于这个例子来说这是可以的。
问题是,如何将其集成到 PyTorch 生态系统中?批量大小经常用于指示
samples
的数量(就像在数据加载器中一样)。那么我应该在哪里插入它,或者我应该子类化什么,才能使其像常规数据加载器一样工作?
from random import randint
from torch.utils.data import Dataset
class DummyDataset(Dataset):
def __init__(self):
data = []
for _ in range(128):
data.append("hello " * randint(64, 176))
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx: int):
return self.data[idx]
if __name__ == '__main__':
dataset = DummyDataset()
def get_batch(max_tokens: int = 250):
data_idxs = list(range(len(dataset)))
batch = []
total_batch_len = 0
while data_idxs:
sample = dataset[data_idxs[0]]
sample_len = len(sample.split())
if total_batch_len + sample_len <= max_tokens:
batch.append(sample)
total_batch_len += sample_len
data_idxs.pop(0)
elif batch:
yield batch
batch = []
total_batch_len = 0
yield batch
# Sanity check that we indeed get all items from the dataset
num_samples = 0
num_batches = 0
for b in get_batch():
num_samples += len(b)
num_batches += 1
print(f"Created {num_batches} batches")
assert num_samples == len(dataset)
也许 torchtext 的 Iterator 和它的
batch_size_fn
可以提供帮助,但我没有使用它的经验(我应该在哪里添加它;它本身是一个数据加载器还是我应该在它周围包装一个数据加载器,等等)。
阅读一些源代码后,似乎您可以在数据加载器的
batch_sampler
中使用任何迭代器。所以以下内容按预期工作。
from random import randint
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
class DummyDataset(Dataset):
def __init__(self):
data = []
for _ in range(128):
data.append("hello " * randint(64, 176))
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx: int):
return self.data[idx]
class TokenBatchSampler:
def __init__(self, max_tokens: int = 250):
self.max_tokens = max_tokens
self.batches = []
self._prepare_dataset()
def __len__(self) -> int:
return len(self.batches)
def __iter__(self):
return iter(self.batches)
def _prepare_dataset(self):
data_idxs = list(range(len(dataset)))
batches = []
batch_idxs = []
total_batch_len = 0
while data_idxs:
sample_idx = data_idxs[0]
sample = dataset[sample_idx]
sample_len = len(sample.split())
if total_batch_len + sample_len <= self.max_tokens:
batch_idxs.append(sample_idx)
total_batch_len += sample_len
data_idxs.pop(0)
elif batch_idxs:
batches.append(batch_idxs)
batch_idxs = []
total_batch_len = 0
batches.append(batch_idxs)
self.batches = batches
if __name__ == "__main__":
dataset = DummyDataset()
sampler = TokenBatchSampler()
dataloader = DataLoader(dataset, batch_sampler=sampler)
# Sanity check that we indeed get all items from the dataset
for epoch in range(3):
num_samples = 0
num_batches = 0
for b in dataloader:
num_samples += len(b)
num_batches += 1
print(f"Created {num_batches} batches in epoch {epoch}")
assert num_samples == len(dataset)
print(f"DataLoader length {len(dataloader)}")