PyTorch transforms.RandomRotation() does not work on Google Colab

Problem description

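I normally work on letter and digit recognition on my own computer. I wanted to move the project to Colab, but unfortunately it raised an error (shown below). After some debugging, I found the line that causes it: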

transforms.RandomRotation(degrees=(90, -90))

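Below is a simple, stripped-down script that reproduces the error. It fails on Colab but runs fine in my local environment. The problem may be related to different PyTorch versions: I have 1.3.1 on my computer, while Colab uses 1.4.0.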

import torch
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# Random rotation followed by conversion to a tensor.
# The RandomRotation line is the one that triggers the error on Colab.
transformOpt = transforms.Compose([
    transforms.RandomRotation(degrees=(90, -90)),
    transforms.ToTensor()
])

# Download MNIST and apply the transform to both splits.
train_set = datasets.MNIST(
    root='', train=True, transform=transformOpt, download=True)
test_set = datasets.MNIST(
    root='', train=False, transform=transformOpt, download=True)

train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=100,
    shuffle=True)
test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=100,
    shuffle=False)

# Fetch one batch and display its first image.
images, labels = next(iter(train_loader))
plt.imshow(images[0].view(28, 28), cmap="gray")
plt.show()

The full error I get when running the sample code above on Google Colab:

TypeError                                 Traceback (most recent call last)

<ipython-input-1-8409db422154> in <module>()
     24     shuffle=False)
     25 
---> 26 images, labels = next(iter(train_loader))
     27 plt.imshow(images[0].view(28, 28), cmap="gray")
     28 plt.show()

10 frames

/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    343 
    344     def __next__(self):
--> 345         data = self._next_data()
    346         self._num_yielded += 1
    347         if self._dataset_kind == _DatasetKind.Iterable and \

/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
    383     def _next_data(self):
    384         index = self._next_index()  # may raise StopIteration
--> 385         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    386         if self._pin_memory:
    387             data = _utils.pin_memory.pin_memory(data)

/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

/usr/local/lib/python3.6/dist-packages/torchvision/datasets/mnist.py in __getitem__(self, index)
     95 
     96         if self.transform is not None:
---> 97             img = self.transform(img)
     98 
     99         if self.target_transform is not None:

/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py in __call__(self, img)
     68     def __call__(self, img):
     69         for t in self.transforms:
---> 70             img = t(img)
     71         return img
     72 

/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py in __call__(self, img)
   1001         angle = self.get_params(self.degrees)
   1002 
-> 1003         return F.rotate(img, angle, self.resample, self.expand, self.center, self.fill)
   1004 
   1005     def __repr__(self):

/usr/local/lib/python3.6/dist-packages/torchvision/transforms/functional.py in rotate(img, angle, resample, expand, center, fill)
    727         fill = tuple([fill] * 3)
    728 
--> 729     return img.rotate(angle, resample, expand, center, fillcolor=fill)
    730 
    731 

/usr/local/lib/python3.6/dist-packages/PIL/Image.py in rotate(self, angle, resample, expand, center, translate, fillcolor)
   2003         w, h = nw, nh
   2004 
-> 2005         return self.transform((w, h), AFFINE, matrix, resample, fillcolor=fillcolor)
   2006 
   2007     def save(self, fp, format=None, **params):

/usr/local/lib/python3.6/dist-packages/PIL/Image.py in transform(self, size, method, data, resample, fill, fillcolor)
   2297             raise ValueError("missing method data")
   2298 
-> 2299         im = new(self.mode, size, fillcolor)
   2300         if method == MESH:
   2301             # list of quads

/usr/local/lib/python3.6/dist-packages/PIL/Image.py in new(mode, size, color)
   2503         im.palette = ImagePalette.ImagePalette()
   2504         color = im.palette.getcolor(color)
-> 2505     return im._new(core.fill(mode, size, color))
   2506 
   2507 

TypeError: function takes exactly 1 argument (3 given)
python machine-learning pytorch google-colaboratory
1 answer

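You are absolutely right. torchvision 0.5 has a bug in the fill argument of RandomRotation(), probably caused by an incompatible Pillow version. The issue has since been fixed (PR#1760), and the fix will ship in the next release.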
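Since the bug depends on which torchvision and Pillow versions are installed, it can help to print them in both environments and compare. Here is a minimal sketch using only the standard library (Python 3.8+). Note that installed_versions is a hypothetical helper name and not part of any of these libraries:

```python
from importlib import metadata


def installed_versions(packages=("torch", "torchvision", "Pillow")):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # package is missing in this environment
    return versions


if __name__ == "__main__":
    # Run this both locally and on Colab, then compare the output.
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```
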

Until then, add fill=(0,) to the RandomRotation transform as a workaround:

transforms.RandomRotation(degrees=(90, -90), fill=(0,))
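
Why a one-element tuple helps: the traceback shows torchvision expanding the default fill into a 3-tuple (fill = tuple([fill] * 3)), while MNIST images are single-channel (PIL mode "L"), so Pillow receives three values where it expects one. The sketch below illustrates the idea of matching the fill length to the number of bands. fill_for_mode and its mode table are hypothetical and purely illustrative. They are not torchvision's or Pillow's API, and the table covers only a few common modes.

```python
# Number of bands for a few common PIL image modes (illustrative, not exhaustive).
BANDS_PER_MODE = {"L": 1, "RGB": 3, "RGBA": 4}


def fill_for_mode(mode, value=0):
    """Hypothetical helper: build a fill tuple with one entry per image band."""
    if mode not in BANDS_PER_MODE:
        raise ValueError(f"unsupported image mode: {mode!r}")
    return (value,) * BANDS_PER_MODE[mode]


# MNIST images are mode "L", so the fill must have exactly one element.
mnist_fill = fill_for_mode("L")    # (0,)
rgb_fill = fill_for_mode("RGB")    # (0, 0, 0)
```

For grayscale data such as MNIST this gives (0,), which matches the workaround above. For an RGB dataset the 3-tuple would be the appropriate length.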