Modifying a Vision Transformer (ViT) model in timm (PyTorch) to use a custom head


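I am working with a Vision Transformer (ViT) model using PyTorch and the timm library. My goal is to modify the ViT model by replacing the default classification head with a custom head that takes the mean of all tokens and then applies a new classification layer.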

The model summary of the default timm ViT ends like this:

       LayerNorm-247             [-1, 197, 768]           1,536
        Identity-248                  [-1, 768]               0
         Dropout-249                  [-1, 768]               0
          Linear-250                 [-1, 1000]         769,000
VisionTransformer-251                 [-1, 1000]               0

To remove the last layer, this is what I have coded so far:

import torch.nn as nn
import timm

class VisionTransformerWithoutHead(nn.Module):

    def __init__(self, model_name):
        super().__init__()

        # Load the ViT model
        vit_model = timm.create_model(model_name, pretrained=True)

        # Remove the final layers
        self.features = nn.Sequential(*list(vit_model.children())[:-1])

    def forward(self, x):
        # Forward pass through the modified model
        output = self.features(x)
        return output

The summary now ends with:

       LayerNorm-247             [-1, 196, 768]           1,536
        Identity-248             [-1, 196, 768]               0
         Dropout-249             [-1, 196, 768]               0

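This reduces the number of tokens from 197 to 196 and appears to drop the class token. I would like to understand why this happens, and whether there is a way to remove only the last layer while keeping the class token.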

Answer
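Why the class token disappears: in timm, `VisionTransformer` does more in its forward pass than run its child modules in order. The class token and the positional embedding are stored as plain `nn.Parameter` attributes (`cls_token`, `pos_embed`), not as submodules, so `children()` never returns them, and the step that prepends the class token and adds the positional embedding lives inside the model's own forward methods. Rebuilding the model as `nn.Sequential(*children())` therefore runs only patch embedding → blocks → norm on the 196 patch tokens, silently dropping the class token *and* the positional embedding. The sketch below reproduces this with a tiny, hypothetical `TinyViT` (not timm's class; small dimensions so it runs instantly) that mimics that structure.

```python
import torch
import torch.nn as nn


class PatchEmbed(nn.Module):
    """Split a 224x224 image into 16x16 patches and project each to `dim`."""

    def __init__(self, dim=32, patch=16):
        super().__init__()
        self.proj = nn.Conv2d(3, dim, kernel_size=patch, stride=patch)

    def forward(self, x):
        x = self.proj(x)                     # [B, dim, 14, 14]
        return x.flatten(2).transpose(1, 2)  # [B, 196, dim]


class TinyViT(nn.Module):
    """Hypothetical, heavily simplified stand-in for timm's VisionTransformer."""

    def __init__(self, dim=32, num_patches=196, num_classes=10):
        super().__init__()
        self.patch_embed = PatchEmbed(dim)
        # Plain Parameters, not submodules: children() never yields them
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
        self.blocks = nn.Sequential(
            nn.TransformerEncoderLayer(dim, nhead=4, dim_feedforward=64, batch_first=True)
        )
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)

    def forward_features(self, x):
        x = self.patch_embed(x)
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        # The class token is prepended here, inside forward, not by any child module
        x = torch.cat([cls, x], dim=1) + self.pos_embed  # [B, 197, dim]
        return self.norm(self.blocks(x))

    def forward(self, x):
        return self.head(self.forward_features(x)[:, 0])  # classify from the class token


model = TinyViT().eval()
x = torch.randn(1, 3, 224, 224)

# Same trick as in the question: rebuild the model from its children, minus the head
naive = nn.Sequential(*list(model.children())[:-1])

with torch.no_grad():
    full_tokens = model.forward_features(x)
    naive_tokens = naive(x)

print([name for name, _ in model.named_children()])  # ['patch_embed', 'blocks', 'norm', 'head']
print(full_tokens.shape)   # torch.Size([1, 197, 32])
print(naive_tokens.shape)  # torch.Size([1, 196, 32])
```

The same reasoning applies to the real timm model: with the pretrained weights, the blocks would also receive patch tokens without their positional embedding, so the features from the `Sequential` rebuild are not the ones the model was trained to produce.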
Keep the timm model intact instead of rebuilding it from children(): remove only the classifier and pool the tokens returned by forward_features.

import torch
import torch.nn as nn
import timm

class CustomVisionTransformer(nn.Module):
    def __init__(self, model_name, num_classes, include_cls_token=True):
        super().__init__()
        # Load the pretrained ViT; num_classes=0 replaces the classifier head with nn.Identity
        self.vit_model = timm.create_model(model_name, pretrained=True, num_classes=0)

        # Custom head; the embedding size is read from the model instead of hard-coding 768
        self.include_cls_token = include_cls_token
        self.custom_head = nn.Linear(self.vit_model.num_features, num_classes)

    def forward(self, x):
        # forward_features returns every token after the final norm, e.g.
        # [B, 197, 768] for vit_base_patch16_224 (class token + 196 patch tokens)
        x = self.vit_model.forward_features(x)

        # Optionally drop the class token (and any other prefix tokens) before pooling
        if not self.include_cls_token:
            x = x[:, self.vit_model.num_prefix_tokens:]

        # Mean over the token dimension -> [B, 768]
        x = x.mean(dim=1)

        # Pass through the custom classification head
        return self.custom_head(x)
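A note on where the averaging must happen: with timm's default `global_pool='token'`, calling the full model (`self.vit_model(x)`) after replacing `head` with `nn.Identity()` already returns one pooled vector per image (the normalized class token, shape `[B, 768]`). A further `torch.mean(x, dim=1)` then averages over the 768 embedding channels and leaves `[B]`, which `nn.Linear(768, num_classes)` cannot accept. The torch-only sketch below illustrates this under stated assumptions: random tensors stand in for `forward_features` output, and `MeanPoolHead` with its `include_prefix` and `num_prefix_tokens` parameters is a hypothetical helper (the latter named after the timm attribute of the same meaning).

```python
import torch
import torch.nn as nn


class MeanPoolHead(nn.Module):
    """Hypothetical custom head: average the token embeddings, then classify."""

    def __init__(self, embed_dim, num_classes, num_prefix_tokens=1, include_prefix=True):
        super().__init__()
        self.num_prefix_tokens = num_prefix_tokens  # 1 = just the class token
        self.include_prefix = include_prefix
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, tokens):
        # Expect the full token sequence [B, N, D], e.g. [B, 197, 768]
        if tokens.dim() != 3:
            raise ValueError(f"expected [B, N, D] tokens, got {tuple(tokens.shape)}")
        if not self.include_prefix:
            tokens = tokens[:, self.num_prefix_tokens:]  # drop the class token
        return self.fc(tokens.mean(dim=1))               # mean over tokens -> [B, D]


torch.manual_seed(0)
tokens = torch.randn(2, 197, 768)  # stand-in for forward_features output
head = MeanPoolHead(embed_dim=768, num_classes=5)
logits = head(tokens)
print(logits.shape)  # torch.Size([2, 5])

# The pitfall: an already-pooled [B, 768] vector averaged over dim=1 collapses to [B]
pooled = tokens[:, 0]             # mimics the class-token vector the full model returns
print(pooled.mean(dim=1).shape)   # torch.Size([2])

try:
    head(pooled)
except ValueError as err:
    print(err)  # expected [B, N, D] tokens, got (2, 768)
```

Whether the class token belongs in the average is a modeling choice; in recent timm versions, the built-in `global_pool='avg'` option averages only the patch tokens and excludes the prefix tokens.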