KMNIST: RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]


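When I run the Wide ResNet code, I get a runtime error:

RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]

I have tried several solutions I found online, but none of them worked; each one just led to a different error. I don't know how to fix this. All the related runtime errors are noted in the comments in the code.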

  elif args.data == 'kmnist':
      normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                       std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

      if args.data_augmentation:
          transform_train = transforms.Compose([
              transforms.RandomCrop(32, padding=4),
              transforms.RandomHorizontalFlip(),
              transforms.ToTensor(),
              normalize,
          ])
      else:
          transform_train = transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize([0.5], [0.5])
          ])

      # With the following portion, I get a different runtime error:
      # RuntimeError: Given groups=1, weight of size 16 3 3 3, expected
      # input[128, 1, 28, 28] to have 3 channels, but got 1 channels instead
      transform_test = transforms.Compose([
          transforms.ToTensor(),
          transforms.Normalize([0.5], [0.5])
      ])


      # If I use the following portion instead, I get
      # RuntimeError: output with shape [1, 28, 28] doesn't match the 
      # broadcast shape [3, 28, 28]
      # transform_test = transforms.Compose([
      #     transforms.ToTensor(),
      #     normalize
      # ])

      # If I use the following portion instead, I get
      # AttributeError: Can't pickle local object 'get_data_loaders.<locals>.<lambda>'
      # transform_test = transforms.Compose([
      #     transforms.ToTensor(),
      #     transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
      #     transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
      # ])

      kwargs = {'num_workers': 1, 'pin_memory': True}
      train_loader = torch.utils.data.DataLoader(
          KMNISTRandomLabels(root='./kmnistdata', train=True, download=True,
                            transform=transform_train, num_classes=args.num_classes,
                            corrupt_prob=args.label_corrupt_prob),
          batch_size=args.batch_size, shuffle=shuffle_train, **kwargs)
      val_loader = torch.utils.data.DataLoader(
          KMNISTRandomLabels(root='./kmnistdata', train=False,
                            transform=transform_test, num_classes=args.num_classes,
                            corrupt_prob=args.label_corrupt_prob),
          batch_size=args.batch_size, shuffle=False, **kwargs)

      return train_loader, val_loader
"""
Fashion-MNIST dataset, with support for random labels
"""
import numpy as np

import torch
import torchvision.datasets as datasets


class FashionMNISTRandomLabels(datasets.FashionMNIST):
  """Fashion-MNIST dataset, with support for randomly corrupted labels.

  Params
  ------
  corrupt_prob: float
    Default 0.0. The probability of a label being replaced with
    a random label.
  num_classes: int
    Default 10. The number of classes in the dataset.
  """
  def __init__(self, corrupt_prob=0.0, num_classes=10, **kwargs):
    super(FashionMNISTRandomLabels, self).__init__(**kwargs)
    self.n_classes = num_classes
    if corrupt_prob > 0:
      self.corrupt_labels(corrupt_prob)

  def corrupt_labels(self, corrupt_prob):
    # recent torchvision versions keep the labels in self.targets;
    # train_labels/test_labels are deprecated, read-only aliases
    labels = np.array(self.targets)
    np.random.seed(12345)
    mask = np.random.rand(len(labels)) <= corrupt_prob
    rnd_labels = np.random.choice(self.n_classes, mask.sum())
    labels[mask] = rnd_labels
    # we need to explicitly cast the labels from np.int64 to
    # builtin int type, otherwise pytorch will fail...
    labels = [int(x) for x in labels]

    self.targets = labels
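The label-corruption step in corrupt_labels can be checked on its own. The sketch below reimplements the same steps on a plain NumPy array; the standalone function and the synthetic labels are illustrative, not part of the original code. It uses a local RandomState(12345), which gives the same draws as the np.random.seed(12345) call above without touching NumPy's global state. Because a selected label is replaced by a class drawn from all num_classes classes, it sometimes gets its original class back, so the fraction of labels that actually change is about corrupt_prob * (1 - 1/num_classes) rather than corrupt_prob.

```python
import numpy as np


def corrupt_labels(labels, corrupt_prob, num_classes, seed=12345):
    """Hypothetical standalone version of corrupt_labels for a plain label array."""
    labels = np.array(labels)  # copy, so the caller's array is left untouched
    rng = np.random.RandomState(seed)
    # Each label is selected independently with probability corrupt_prob.
    mask = rng.rand(len(labels)) <= corrupt_prob
    # Selected labels get a uniformly random class, which may equal the old one.
    labels[mask] = rng.choice(num_classes, mask.sum())
    # Cast to builtin int, as the original code does.
    return [int(x) for x in labels], mask


clean = np.arange(10000) % 10
noisy, mask = corrupt_labels(clean, corrupt_prob=0.3, num_classes=10)
changed = float(np.mean(np.array(noisy) != clean))
print(f"selected: {mask.mean():.3f}, actually changed: {changed:.3f}")
```

With corrupt_prob=0.3 and 10 classes, roughly 30% of the labels are selected but only about 27% end up different.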
2 Answers

It works if you set it up like this:

    [...]
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Lambda(lambda x: x.repeat(3, 1, 1)),


You provided a mean and a standard deviation that each contain 3 elements, as defined by your list comprehensions:

normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                 std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
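Three values fail on a one-channel image because torchvision's Normalize reshapes mean and std to [C, 1, 1] and subtracts them in place (on a copy of the input unless inplace=True). The minimal sketch below uses plain PyTorch tensors rather than the transform itself to reproduce the same error. Broadcasting [1, 28, 28] against [3, 1, 1] produces a [3, 28, 28] result, which cannot be written back into a [1, 28, 28] tensor.

```python
import torch

# ToTensor() output for one KMNIST image: a single channel.
img = torch.rand(1, 28, 28)

# The question's 3-element mean, shaped [3, 1, 1] as Normalize does internally.
mean = torch.tensor([125.3, 123.0, 113.9]).div(255.0).view(3, 1, 1)

# Out of place, broadcasting [1, 28, 28] with [3, 1, 1] just makes a new tensor.
print(tuple((img - mean).shape))  # (3, 28, 28)

# In place, the [3, 28, 28] result would have to fit into a [1, 28, 28] tensor,
# so PyTorch raises the RuntimeError from the question.
try:
    img.clone().sub_(mean)
except RuntimeError as err:
    print("RuntimeError:", err)
```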

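Because your images have shape [1, 28, 28], you should instead use a normalization transform with a single mean and a single standard deviation. Update your code to something like: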

normalize = transforms.Normalize(mean=X,std=Y)