(Pytorch) mat1 和 mat2 形状不能相乘(212992x13 和 1280x3)

问题描述 投票:0回答:1

我正在尝试使用自定义数据集在 Pytorch 预训练模型上进行迁移学习。 目前,我收到错误如下

mat1 and mat2 shapes cannot be multiplied (212992x13 and 1280x3) 
在训练自定义模型期间。

当我尝试使用高效网络时,下面的代码可以工作并且训练成功,但是当我使用像挤压网络这样的模型时,我收到错误

作品:

weights = torchvision.models.EfficientNet_B0_Weights.DEFAULT
model = torchvision.models.efficientnet_b0(weights=weights).to(device)

不起作用:

weights = torchvision.models.SqueezeNet1_0_Weights.DEFAULT
model = torchvision.models.squeezenet1_0(weights=weights).to(device)

火车:

auto_transforms = weights.transforms()
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(train_dir=train_dir, test_dir=test_dir, transform=auto_transforms, batch_size=32)

for param in model.features.parameters():
    param.requires_grad = False #Freeze layers

torch.manual_seed(42)
output_shape = len(class_names)
model.classifier = torch.nn.Sequential(
    torch.nn.Dropout(p=0.2, inplace=True),
    torch.nn.Linear(in_features=1280,
                    out_features=output_shape,
                    bias=True)).to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

#ERROR DURING TRAIN
results = engine.train(model=model, train_dataloader=train_dataloader, test_dataloader=test_dataloader, optimizer=optimizer, loss_fn=loss_fn, epochs=100, device=device)

训练图像尺寸为512x512

为了确保这不是转换问题,我使用了自动转换,但问题仍然存在。

虽然存在类似的主题mat1和mat2形状不能相乘(128x4和128x64),但它完全基于创建新的顺序模型,而我正在尝试在预训练模型上使用迁移学习。

machine-learning deep-learning pytorch transfer-learning torchvision
1个回答
0
投票

如果你用恒等函数替换你的分类器,你就会看到问题是什么:

model.classifier = nn.Identity()
model(torch.rand(2,3,512,512)).shape
torch.Size([2, 492032])

分类器线性层的

in_features
应该是
492032
,而不是
1280
。另外,如果你与
SqueezeNet
的源代码进行比较,你会发现
model.classfier
不包含线性层,而是包含卷积层和池化层,第81行

final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
self.classifier = nn.Sequential(
   nn.Dropout(p=dropout), 
   final_conv, 
   nn.ReLU(inplace=True), 
   nn.AdaptiveAvgPool2d((1, 1)))
© www.soinside.com 2019 - 2024. All rights reserved.