防止PyTorch数据集迭代超过数据集的长度

问题描述 投票:0回答:1

我正在使用自定义PyTorch数据集,具有以下内容:

class ImageDataset(Dataset):
    def __init__(self, input_dir, input_num, input_format, transform=None):
        self.input_num = input_num
        # etc
    def __len__ (self):
        return self.input_num
    def __getitem__(self,idx):
        targetnum = idx % self.input_num
        # etc

但是,当我遍历此数据集时,迭代将循环回数据集的开头,而不是终止于数据集的末尾。这有效地成为迭代器中的无限循环,epoch print语句永远不会出现在后续的纪元中。

train_dataset=ImageDataset(input_dir = 'path/to/directory', 
                           input_num = 300, input_format = "mask") # Size 300
num_epochs = 10
for epoch in range(num_epochs):
    print("EPOCH " + str(epoch+1) + "\n")
    num = 0
    for data in train_dataset:
        print(num, end=" ")
        num += 1
        # etc

打印输出(...表示两者之间的值):

EPOCH 1
0 1 2 3 4 5 6 7 ... 298 299 300 301 302 303 304 305 ... 597 598 599 600 601 602 603 604 ...

为什么数据集上的基本迭代继续超过DataSet的已定义__len__,以及如何在使用此方法时确保数据集上的迭代在达到数据集的长度后终止(或者手动迭代数据集的范围)长度唯一的解决方案)?

谢谢。

python pytorch
1个回答
0
投票

Dataset类没有实现StopIteration信号。

for循环侦听StopIteration。 for语句的目的是遍历迭代器提供的序列,异常用于表示迭代器现在已完成...

更多:Why does next raise a 'StopIteration', but 'for' do a normal return? | The Iterator Protocol

© www.soinside.com 2019 - 2024. All rights reserved.