火炬数据集循环太远

问题描述 投票:2回答:1

为什么这个数据集尽量重复过去的最后一个元素

from torch.utils.data.dataset import Dataset
class DumbDataset(Dataset):
    def __init__(self, dct):
        self.dct = dct
        self.mapping = dict(enumerate(dct))
    def __getitem__(self, index):
        return self.dct[self.mapping[index]]

    def __len__(self):
        print('called')
        return len(self.dct)

ds = DumbDataset({'a': 'aword', 'b': 'another_words'})

for k in ds: print(k)

这引发KeyError异常:2,这一点我不明白,因为对象的长度2.不宜迭代得到的StopIteration一旦它耗尽?

python pytorch
1个回答
3
投票

为什么你的代码引发KeyError的原因是,Dataset does not implement __iter__(),因此在使用时,一个for循环的Python回落到起始于指数0并呼吁__getitem__直到IndexError提高,为讨论here。您可以修改DumbDataset有它引发IndexError这样的工作时,指数超出范围

def __getitem__(self, index):
    if index >= len(self): raise IndexError
    return self.dct[self.mapping[index]]

然后你的循环

for k in ds:
    print(k)

如你预期会工作。在另一方面,对于火炬数据集的典型模板是您可以通过它们与分度环

for i in range(len(ds)):
    k = ds[k]
    print(k)

或者说你包起来返回分批元素DataLoader

generator = DataLoader(ds)
for k in generator:
    print(k)
© www.soinside.com 2019 - 2024. All rights reserved.