PyTorch CustomDataset TypeError: list indices must be integers or slices, not list

Question

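I am getting an error when using a custom PyTorch dataset. It is strange to me because the code was working and I did not change anything in it. Here is the scenario:

  1. After building a deep learning model, I tested it by training for 10 to 100 epochs. It worked fine, but I found that the model needs more epochs of training to get better results.

  2. So I changed the number of epochs to 500. My GPU crashed, maybe because I was printing results every 10 epochs and ran out of memory (I don't know what the real problem was).

  3. After I restarted the GPU, Jupyter Notebook threw a server error with status code 500.

  4. I searched the internet and found a fix for my Jupyter Notebook, which was to run the following command: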

    pip install --upgrade nbconvert

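  5. After that, the code no longer worked. I tried to debug it and found something strange:

  • If I put the custom dataset class in a Python file and import it into the Jupyter notebook, I get the error below (the idx argument of __getitem__ is a list).
  • If I put the class directly in a Jupyter notebook cell and use it from that notebook, the custom dataset class works (the idx argument of __getitem__ is an integer).

Thanks in advance for your answers.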

Here is the custom dataset class:

Put this in src/custom_dataset.py, for example:

import os
from natsort import natsorted
from PIL import Image
# Let's see if we have an available GPU
from datasets import Dataset

class LoadPairedDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = os.listdir(root_dir)


    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        print(idx)
        img_name = os.path.join(self.root_dir, self.images[idx])
        image = Image.open(img_name)

        if self.transform:
            image = self.transform(image)

        return image

Here is the code I use to call the custom dataset:

Put this in a Jupyter notebook cell:

# Imports
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from src.custom_dataset import LoadPairedDataset
 
# Define your own class LoadFromFolder
class CustomImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = os.listdir(root_dir)


    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        print(idx)  
        img_name = os.path.join(self.root_dir, self.images[idx])
        image = Image.open(img_name)

        if self.transform:
            image = self.transform(image)

        return image
    

base_path = "../lol-custom"

# dataloader = {"train_n": None, "train_p": None}
transform = transforms.Compose([
            transforms.ToTensor()                       
        ])

train_data = CustomImageDataset(root_dir=base_path + "/train/low", transform=transform)
dataloader = torch.utils.data.DataLoader(
    train_data,
    batch_size=5,
    sampler=None,
    num_workers=0
)
# The output will be:
# 0 1 2 3 4 from the print(idx) in the __getitem__ function in CustomImageDataset class
# torch.Size([5, 3, 400, 600]) from the below print
print(next(iter(dataloader)).shape) # This will print 0 1 2 3 4 
print("######### The below throw an error ##############")

train_data = LoadPairedDataset(root_dir=base_path + "/train/low", transform=transform)
dataloader = torch.utils.data.DataLoader(
    train_data,
    batch_size=5,
    sampler=None,
    num_workers=0
)
# The output will be:
# [0, 1, 2, 3, 4] from the print(idx) in the __getitem__ function in CustomImageDataset class
# THEN ERROR: TypeError: list indices must be integers or slices, not list
print(next(iter(dataloader)).shape) 

I am using torch 2.0.1 and Python 3.9.18.

Finally, here is the output and the stack trace of the error:

0
1
2
3
4
torch.Size([5, 3, 400, 600])
######### The below throw an error ##############
[0, 1, 2, 3, 4]

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[1], line 61
     52 dataloader = torch.utils.data.DataLoader(
     53     train_data,
     54     batch_size=5,
     55     sampler=None,
     56     num_workers=0
     57 )
     58 # This output will be:
     59 # [0, 1, 2, 3, 4] from the print(idx) in the __getitem__ function in CustomImageDataset class
     60 # THEN ERROR: TypeError: list indices must be integers or slices, not list
---> 61 print(next(iter(dataloader)).shape)

File ~\anaconda3\envs\mmie\lib\site-packages\torch\utils\data\dataloader.py:633, in _BaseDataLoaderIter.__next__(self)
    630 if self._sampler_iter is None:
    631     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    632     self._reset()  # type: ignore[call-arg]
--> 633 data = self._next_data()
    634 self._num_yielded += 1
    635 if self._dataset_kind == _DatasetKind.Iterable and \
    636         self._IterableDataset_len_called is not None and \
    637         self._num_yielded > self._IterableDataset_len_called:

File ~\anaconda3\envs\mmie\lib\site-packages\torch\utils\data\dataloader.py:677, in _SingleProcessDataLoaderIter._next_data(self)
    675 def _next_data(self):
    676     index = self._next_index()  # may raise StopIteration
--> 677     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    678     if self._pin_memory:
    679         data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

File ~\anaconda3\envs\mmie\lib\site-packages\torch\utils\data\_utils\fetch.py:49, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
     47 if self.auto_collation:
     48     if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
---> 49         data = self.dataset.__getitems__(possibly_batched_index)
     50     else:
     51         data = [self.dataset[idx] for idx in possibly_batched_index]

File ~\anaconda3\envs\mmie\lib\site-packages\datasets\arrow_dataset.py:2807, in Dataset.__getitems__(self, keys)
   2805 def __getitems__(self, keys: List) -> List:
   2806     """Can be used to get a batch using a list of integers indices."""
-> 2807     batch = self.__getitem__(keys)
   2808     n_examples = len(batch[next(iter(batch))])
   2809     return [{col: array[i] for col, array in batch.items()} for i in range(n_examples)]

File ~\Projects\mmie\src\custom_dataset.py:21, in LoadPairedDataset.__getitem__(self, idx)
     19 def __getitem__(self, idx):
     20     print(idx)
---> 21     img_name = os.path.join(self.root_dir, self.images[idx])
     22     image = Image.open(img_name)
     24     if self.transform:

TypeError: list indices must be integers or slices, not list


python jupyter-notebook pytorch torch
1 Answer

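This is probably because both of your custom datasets inherit from a class named Dataset, but Dataset does not refer to the same class in the two places.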
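One way to confirm this is to print where a class's parents actually come from. The helper names below (describe_bases, has_batched_getitems) and the toy classes are hypothetical and exist only for illustration. This is a minimal sketch, not part of either library:

```python
def describe_bases(cls):
    """Return the fully qualified name of every class in cls's MRO."""
    return [f"{base.__module__}.{base.__qualname__}" for base in cls.__mro__]


def has_batched_getitems(cls):
    """True if the class (or one of its parents) defines a callable __getitems__."""
    return callable(getattr(cls, "__getitems__", None))


# Toy classes so the sketch runs on its own, without torch or datasets.
class ToyBase:
    def __getitems__(self, keys):
        return [self[k] for k in keys]


class ToyDataset(ToyBase):
    def __getitem__(self, idx):
        return idx


print(describe_bases(ToyDataset))
print(has_batched_getitems(ToyDataset))
```

In the notebook from the question, calling describe_bases(LoadPairedDataset) should list an entry from datasets.arrow_dataset (the module named in the traceback), while describe_bases(CustomImageDataset) should list one from torch.utils.data.dataset.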

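In the separate file, Dataset is imported with "from datasets import Dataset", but in the Jupyter cell it is imported with "from torch.utils.data import Dataset". These are clearly two different classes. I suggest using the import that works, "from torch.utils.data import Dataset", in the separate file as well.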
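To see why the base class changes what __getitem__ receives, here is a pure-Python sketch of the branch visible in the traceback (fetch.py, lines 47 to 51). The fetch function and the two base classes are simplified stand-ins written for this example and are not the real library code. PlainBase assumes a base class with no __getitems__, which matches the per-integer calls seen in the notebook run:

```python
def fetch(dataset, batch_indices):
    """Simplified copy of the auto-collation branch shown in the traceback."""
    if hasattr(dataset, "__getitems__") and dataset.__getitems__:
        # Batched path: the whole list of indices is handed over at once.
        return dataset.__getitems__(batch_indices)
    # Per-sample path: __getitem__ is called with one integer at a time.
    return [dataset[idx] for idx in batch_indices]


class PlainBase:
    """Stand-in for a base class that defines no __getitems__."""


class BatchedBase:
    """Stand-in for a base class whose __getitems__ forwards the list to __getitem__."""

    def __getitems__(self, keys):
        return self.__getitem__(keys)


class NotebookDataset(PlainBase):
    def __init__(self, images):
        self.images = images

    def __getitem__(self, idx):
        return self.images[idx]  # idx is an int here


class FileDataset(BatchedBase):
    def __init__(self, images):
        self.images = images

    def __getitem__(self, idx):
        return self.images[idx]  # idx is a list here, so indexing fails


files = ["a.png", "b.png", "c.png"]
print(fetch(NotebookDataset(files), [0, 1, 2]))

try:
    fetch(FileDataset(files), [0, 1, 2])
except TypeError as exc:
    print("TypeError:", exc)
```

The second call raises the same kind of TypeError as in the question, because the inherited __getitems__ passes the full index list to a __getitem__ that indexes a plain Python list.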
