没有目标的验证数据

问题描述 投票:0回答:1

我有一个要由 CNN 模型分类的图像验证数据集。我想使用

pytorch
加载这些图像。
torchvision.datasets.ImageFolder()
函数不起作用,因为没有目标,因为数据集未分类。我假设我需要编写一个自定义数据集类,稍后我会将其放入
torch.utils.data.DataLoader()
。我在网上搜索过,但我仍然不太明白课程应该是什么样子。

我试过这个

import torch
from torch.utils.data import Dataset
from torchvision.io import read_image, ImageReadMode
import os


class Dset(Dataset):
    def __init__(self, dir: str, transform=None) -> None:
        self.transform = transform
        self.images = os.listdir(dir)
        self.dir = dir
    
    def __getitem__(self, index: int) -> torch.Tensor:
        image = read_image(f'{self.dir}/{self.images[index]}', mode=ImageReadMode.RGB) # upd
        if self.transform is not None:
            image = self.transform(image)
        return image

    def __len__(self) -> int:
        return len(self.images)

但是在这个单元格之后(所有图像都在

.data/
中)

from torchvision import transforms

batch_size = 64
transform = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])
data = Dset('data', transform=transform)
trainloader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)

images, labels = iter(trainloader)

我遇到这个错误:

TypeError: pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>

python machine-learning pytorch computer-vision
1个回答
0
投票

正如评论中所讨论的 - 问题是你的图像有一个 Alpha 通道。您可以修改

read_image
函数以从输入图像中删除 Alpha 通道,如下所示:

image = read_image(f'{self.dir}/{self.images[index]}', mode=ImageReadMode.RGB)

对于其他模式,您可以检查ImageReadMode类

© www.soinside.com 2019 - 2024. All rights reserved.