我正在尝试对视觉变压器模型进行回归,但我无法用回归层替换最后一层分类
当我尝试初始化模型时出现此错误
class RegressionViT(nn.Module):
def __init__(self, in_features=224 * 224 * 3, num_classes=1, pretrained=True):
super(RegressionViT, self).__init__()
self.vit_b_16 = vit_b_16(pretrained=pretrained) # Load pre-trained weights
# Replace the final classification layer with a regression head
self.regressor = nn.Linear(self.vit_b_16.heads.in_features, num_classes)
def forward(self, x):
x = self.vit_b_16(x)
x = self.regressor(x)
return x
VisionTransformer
的源代码,您会在本节中注意到self.heads
是顺序层,而不是线性层。默认情况下,它仅包含与最终分类层相对应的单个层head
。要覆盖该层,您可以执行以下操作:
heads = self.vit_b_16.heads
heads.head = nn.Linear(heads.head.in_features, num_classes)