数据加载器/采样器/整理器根据样本内容(序列长度)创建批次

问题描述 投票:0回答:1

我正在使用数据集和数据加载器、整理函数和采样器将其他人的代码转换为更整洁的 torch-y 管道。虽然我以前做过这样的工作,但我不知道如何解决以下问题。

数据集包含句子作为样本。因此,每个样本都有许多单词(或

tokens
),我们可以通过天真地在空白(
sample.split()
)上分割样本来获得这些单词。这样的虚拟数据集可以如下所示:

from random import randint

from torch.utils.data import Dataset


class DummyDataset(Dataset):
    def __init__(self):
        data = []
        for _ in range(128):
            data.append("hello " * randint(64, 176))
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        return self.data[idx]

现在我希望能够加载数据,以便最大。批次中的tokens数量不超过250。这意味着批次大小在迭代之间可能不同。一个批次可能包含两个样本,总令牌数不超过 250 个(例如 127 + 77),另一个批次可以包含三个样本(66+66+66)。现在,其核心功能相当简单。完整示例如下;没有通过长度排序或其他方式进行优化,但对于这个例子来说这是可以的。

问题是,如何将其集成到 PyTorch 生态系统中?批量大小经常用于指示

samples
的数量(就像在数据加载器中一样)。那么我应该在哪里插入它,或者我应该子类化什么,才能使其像常规数据加载器一样工作?

from random import randint

from torch.utils.data import Dataset

class DummyDataset(Dataset):
    def __init__(self):
        data = []
        for _ in range(128):
            data.append("hello " * randint(64, 176))
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        return self.data[idx]


if __name__ == '__main__':
    dataset = DummyDataset()

    def get_batch(max_tokens: int = 250):
        data_idxs = list(range(len(dataset)))

        batch = []
        total_batch_len = 0
        while data_idxs:
            sample = dataset[data_idxs[0]]
            sample_len = len(sample.split())

            if total_batch_len + sample_len <= max_tokens:
                batch.append(sample)
                total_batch_len += sample_len
                data_idxs.pop(0)
            elif batch:
                yield batch
                batch = []
                total_batch_len = 0

        yield batch

    # Sanity check that we indeed get all items from the dataset
    num_samples = 0
    num_batches = 0
    for b in get_batch():
        num_samples += len(b)
        num_batches += 1

    print(f"Created {num_batches} batches")
    assert num_samples == len(dataset)

也许 torchtext 的 Iterator 和它的

batch_size_fn
可以提供帮助,但我没有使用它的经验(我应该在哪里添加它;它本身是一个数据加载器还是我应该在它周围包装一个数据加载器,等等)。

python pytorch batch-processing dataloader
1个回答
0
投票

阅读一些源代码后,似乎您可以在数据加载器的

batch_sampler
中使用任何迭代器。所以以下内容按预期工作。

from random import randint

from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader


class DummyDataset(Dataset):
    def __init__(self):
        data = []
        for _ in range(128):
            data.append("hello " * randint(64, 176))
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        return self.data[idx]


class TokenBatchSampler:
    def __init__(self, max_tokens: int = 250):
        self.max_tokens = max_tokens
        self.batches = []
        self._prepare_dataset()

    def __len__(self) -> int:
        return len(self.batches)

    def __iter__(self):
        return iter(self.batches)

    def _prepare_dataset(self):
        data_idxs = list(range(len(dataset)))

        batches = []
        batch_idxs = []
        total_batch_len = 0
        while data_idxs:
            sample_idx = data_idxs[0]
            sample = dataset[sample_idx]
            sample_len = len(sample.split())

            if total_batch_len + sample_len <= self.max_tokens:
                batch_idxs.append(sample_idx)
                total_batch_len += sample_len
                data_idxs.pop(0)
            elif batch_idxs:
                batches.append(batch_idxs)
                batch_idxs = []
                total_batch_len = 0

        batches.append(batch_idxs)

        self.batches = batches


if __name__ == "__main__":
    dataset = DummyDataset()

    sampler = TokenBatchSampler()
    dataloader = DataLoader(dataset, batch_sampler=sampler)
    # Sanity check that we indeed get all items from the dataset
    for epoch in range(3):
        num_samples = 0
        num_batches = 0
        for b in dataloader:
            num_samples += len(b)
            num_batches += 1

        print(f"Created {num_batches} batches in epoch {epoch}")
        assert num_samples == len(dataset)

    print(f"DataLoader length {len(dataloader)}")

© www.soinside.com 2019 - 2024. All rights reserved.