PyTorch DataLoader 和 Matplotlib 的 Imshow 之间的图像分类任务问题

问题描述 投票:0回答:1

我目前正在研究涉及图像数据的二元分类任务。首先,我必须检查我的数据集。但是,我遇到了

DataLoader
的问题。

PyTorch官方网站上有这样写的

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

当他们设置

training data
时,他们将数据类型转换为张量。他们只是使用 imshow(matplotlib)。但是当我自己尝试这个过程时,错误
TypeError: pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>
困扰着我。

当我向 GPT4 询问这个问题时,它说“PyTorch 和 matplotlib 是兼容的。”然而,当我再次询问我提供的代码时,它提到:“在使用 imshow 之前,您需要将 PyTorch 张量转换为 NumPy 数组。”哪一种说法是准确的?

pytorch dataset classification dataloader
1个回答
0
投票

第二个应该是正确的说法。你应该改成这个

plt.imshow(img.squeeze().numpy(), cmap="gray")
© www.soinside.com 2019 - 2024. All rights reserved.