RuntimeError:预期用于4维权重[64、3、7、7]的4维输入,但改为使用大小[3、32、32]的3维输入]] << [

问题描述 投票:0回答:1
我对PyTorch和神经网络一般还是陌生的。我正在尝试从CIFAR-10数据集的torchvision实现resnet-50模型。

import torchvision import torch import torch.nn as nn from torch import optim import os import torchvision.transforms as transforms from torch.utils.data import DataLoader import numpy as np from collections import OrderedDict import matplotlib.pyplot as plt transformations=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]) trainset=torchvision.datasets.CIFAR10(root='./CIFAR10',download=True,transform=transformations,train=True) testset=torchvision.datasets.CIFAR10(root='./CIFAR10',download=True,transform=transformations,train=False) trainloader=DataLoader(dataset=trainset,batch_size=4) testloader=DataLoader(dataset=testset,batch_size=4) inputs,labels=next(iter(trainset)) inputs.size() resnet=torchvision.models.resnet50(pretrained=True) if torch.cuda.is_available(): resnet=resnet.cuda() inputs,labels=inputs.cuda(),torch.Tensor(labels).cuda() outputs=resnet(inputs)

输出

-------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) <ipython-input-6-904acb410fe4> in <module>() 6 inputs,labels=inputs.cuda(),torch.Tensor(labels).cuda() 7 ----> 8 outputs=resnet(inputs) 5 frames /usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight) 344 _pair(0), self.dilation, self.groups) 345 return F.conv2d(input, weight, self.bias, self.stride, --> 346 self.padding, self.dilation, self.groups) 347 348 def forward(self, input): RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 3, 7, 7], but got 3-dimensional input of size [3, 32, 32] instead
由于某种原因,数据集是否存在问题,如果没有,我如何给出4维输入? ResNet-50的Torchvision实现是否不能用于CIFAR-10?

我对PyTorch和神经网络一般还是陌生的。我正在尝试从CIFAR-10数据集中的torchvision实现resnet-50模型。导入torchvision导入torch导入torch.nn作为nn,从...

deep-learning pytorch resnet transfer-learning torchvision
1个回答
1
投票
当前,您正在遍历数据集,这就是为什么要获取(三维)单幅图像的原因。实际上,您需要遍历数据加载器以获取4维图像批处理。因此,您只需要更改以下行:
© www.soinside.com 2019 - 2024. All rights reserved.