我有一个要由 CNN 模型分类的图像验证数据集。我想使用
pytorch
加载这些图像。 torchvision.datasets.ImageFolder()
函数不起作用,因为没有目标,因为数据集未分类。我假设我需要编写一个自定义数据集类,稍后我会将其放入torch.utils.data.DataLoader()
。我在网上搜索过,但我仍然不太明白课程应该是什么样子。
我试过这个
import torch
from torch.utils.data import Dataset
from torchvision.io import read_image, ImageReadMode
import os
class Dset(Dataset):
def __init__(self, dir: str, transform=None) -> None:
self.transform = transform
self.images = os.listdir(dir)
self.dir = dir
def __getitem__(self, index: int) -> torch.Tensor:
image = read_image(f'{self.dir}/{self.images[index]}', mode=ImageReadMode.RGB) # upd
if self.transform is not None:
image = self.transform(image)
return image
def __len__(self) -> int:
return len(self.images)
但是在这个单元格之后(所有图像都在
.data/
中)
from torchvision import transforms
batch_size = 64
transform = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])
data = Dset('data', transform=transform)
trainloader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)
images, labels = iter(trainloader)
我遇到这个错误:
TypeError: pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>
正如评论中所讨论的 - 问题是你的图像有一个 Alpha 通道。您可以修改
read_image
函数以从输入图像中删除 Alpha 通道,如下所示:
image = read_image(f'{self.dir}/{self.images[index]}', mode=ImageReadMode.RGB)
对于其他模式,您可以检查ImageReadMode类。