为什么这个数据集尽量重复过去的最后一个元素
from torch.utils.data.dataset import Dataset
class DumbDataset(Dataset):
def __init__(self, dct):
self.dct = dct
self.mapping = dict(enumerate(dct))
def __getitem__(self, index):
return self.dct[self.mapping[index]]
def __len__(self):
print('called')
return len(self.dct)
ds = DumbDataset({'a': 'aword', 'b': 'another_words'})
for k in ds: print(k)
这引发KeyError异常:2,这一点我不明白,因为对象的长度2.不宜迭代得到的StopIteration一旦它耗尽?
为什么你的代码引发KeyError
的原因是,Dataset
does not implement __iter__()
,因此在使用时,一个for循环的Python回落到起始于指数0
并呼吁__getitem__
直到IndexError
提高,为讨论here。您可以修改DumbDataset
有它引发IndexError
这样的工作时,指数超出范围
def __getitem__(self, index):
if index >= len(self): raise IndexError
return self.dct[self.mapping[index]]
然后你的循环
for k in ds:
print(k)
如你预期会工作。在另一方面,对于火炬数据集的典型模板是您可以通过它们与分度环
for i in range(len(ds)):
k = ds[k]
print(k)
或者说你包起来返回分批元素DataLoader
generator = DataLoader(ds)
for k in generator:
print(k)