Problem with CNN code, possibly in the class definition


I built a CNN:

import numpy as np
import torch
import torch.nn as nn

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.n = 10
        kernel_size = 3
        padding = (kernel_size - 1) / 2
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=self.n, kernel_size=kernel_size, stride=(2, 2), padding=padding)

        self.conv2 = nn.Conv2d(in_channels=self.n, out_channels=2*self.n, kernel_size=kernel_size, stride=(2, 2), padding=padding)

        self.conv3 = nn.Conv2d(in_channels=2*self.n, out_channels=4*self.n, kernel_size=kernel_size, stride=(2, 2), padding=padding)

        self.conv4 = nn.Conv2d(in_channels=4*self.n, out_channels=8*self.n, kernel_size=kernel_size, stride=(2, 2), padding=padding)

        self.fc1 = nn.Linear(8 * self.n * 7 * 4, 100)
        self.fc2 = nn.Linear(100, 2)

    def forward(self, inp):
        out = nn.functional.relu(self.conv1(inp))
        out = nn.functional.relu(self.conv2(out))
        out = nn.functional.relu(self.conv3(out))
        out = nn.functional.relu(self.conv4(out))

        out = out.View(-1, 8 * self.n * 7 * 4)
        out = nn.functional.relu(self.fc1(out))
        out = self.fc2(out)

        return out

The input inp is a tensor of shape (N, 3, 448, 224), and the output should have shape (N, 2).

The problem is that I get this error:

TypeError: conv2d() received an invalid combination of arguments - got (Tensor, Parameter, Parameter, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (Tensor, !Parameter!, !Parameter!, !tuple of (int, int)!, !tuple of (float, float)!, !tuple of (int, int)!, int)
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (Tensor, !Parameter!, !Parameter!, !tuple of (int, int)!, !tuple of (float, float)!, !tuple of (int, int)!, int)

Any suggestions on how to fix this?

machine-learning pytorch conv-neural-network
1 Answer

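There are two problems with the code:

  1. padding must be an int, but (kernel_size - 1) / 2 produces a float. Convert it to an int:

    padding = int(kernel_size / 2)

  2. The tensor method is view, written in lowercase, not View :)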

         out = out.view(-1, 8 * self.n * 7 * 4)
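As a sanity check on the shapes, here is a minimal sketch in plain Python (no PyTorch needed) that applies the output-size formula from the `nn.Conv2d` documentation to the four layers as posted. `conv_out` is a hypothetical helper introduced only for this illustration, and the sketch assumes the network has no downsampling other than the four stride-2 convolutions shown in the question.

```python
def conv_out(size, kernel_size=3, stride=2, padding=1, dilation=1):
    """Output length of one spatial dimension after a Conv2d layer."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

kernel_size = 3
padding = (kernel_size - 1) // 2  # integer division, so padding stays an int

# Input spatial size from the question: (N, 3, 448, 224)
h, w = 448, 224
for _ in range(4):  # conv1 .. conv4, each with stride 2
    h = conv_out(h, kernel_size, 2, padding)
    w = conv_out(w, kernel_size, 2, padding)

n = 10  # self.n in the question
flat_features = 8 * n * h * w  # channels * height * width after conv4

print(padding, (h, w), flat_features)
```

Under that assumption the feature map after conv4 is 28 x 14 rather than 7 x 4, so `view(-1, 8 * self.n * 7 * 4)` would not raise an error but would fold the extra elements into the batch dimension, giving an output of shape (14 * N, 2) instead of (N, 2). If that is what you see, sizing the layer as `nn.Linear(8 * self.n * 28 * 14, 100)` and flattening with `out.view(out.size(0), -1)` should keep the batch dimension intact. If your real network has more downsampling than what was posted, the numbers will differ.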
    