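I'm working with Vision Transformer (ViT) models using PyTorch and the timm library. My goal is to modify a ViT model by replacing the default classification head with a custom head that takes the mean of all tokens and feeds it into a new classification layer.
The default summary of a ViT model in timm ends like this: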
LayerNorm-247 [-1, 197, 768] 1,536
Identity-248 [-1, 768] 0
Dropout-249 [-1, 768] 0
Linear-250 [-1, 1000] 769,000
VisionTransformer-251 [-1, 1000] 0
To remove the last layer, this is what I have coded so far:
import torch.nn as nn
import timm


class VisionTransformerWithoutHead(nn.Module):
    def __init__(self, model_name):
        super(VisionTransformerWithoutHead, self).__init__()
        # Load the ViT model
        vit_model = timm.create_model(model_name, pretrained=True)
        # Remove the final layer by rebuilding the model from all direct child modules except the last one
        self.features = nn.Sequential(*list(vit_model.children())[:-1])

    def forward(self, x):
        # Forward pass through the modified model
        output = self.features(x)
        return output
The summary now ends with:
LayerNorm-247 [-1, 196, 768] 1,536
Identity-248 [-1, 196, 768] 0
Dropout-249 [-1, 196, 768] 0
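This reduces the number of tokens from 197 to 196, and the class token appears to have been removed. I'd like to understand why this happens, and whether there is a way to remove only the last layer while keeping the class token.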
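A likely explanation, illustrated with a toy model rather than timm itself: in timm's `VisionTransformer`, `cls_token` and `pos_embed` are `nn.Parameter` attributes, not child modules, and the class token is concatenated (and the positional embedding added) inside the model's `forward_features` method rather than by any submodule. `nn.Sequential(*model.children())` only chains the submodules in registration order, so that step is silently skipped, which drops the class token and also leaves the patch tokens without positional embeddings. The sketch below uses hypothetical `ToyViT` / `ToyPatchEmbed` classes with a tiny embedding size (32) that only mimic this layout; the exact internals of timm's `forward_features` vary between versions. Calling `forward_features` directly keeps the full sequence.

```python
import torch
import torch.nn as nn


class ToyPatchEmbed(nn.Module):
    """Hypothetical patch embedding: [B, 3, H, W] -> [B, N, D]."""

    def __init__(self, patch_size=16, embed_dim=32):
        super().__init__()
        self.proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        return self.proj(x).flatten(2).transpose(1, 2)


class ToyViT(nn.Module):
    """Hypothetical, heavily simplified stand-in for timm's VisionTransformer."""

    def __init__(self, embed_dim=32, num_patches=196, num_classes=10):
        super().__init__()
        self.patch_embed = ToyPatchEmbed(embed_dim=embed_dim)
        # Plain parameters: these are NOT returned by .children()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.blocks = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.GELU())
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward_features(self, x):
        x = self.patch_embed(x)                          # [B, 196, D]
        cls = self.cls_token.expand(x.shape[0], -1, -1)  # [B, 1, D]
        x = torch.cat([cls, x], dim=1) + self.pos_embed  # [B, 197, D]
        return self.norm(self.blocks(x))

    def forward(self, x):
        x = self.forward_features(x)
        return self.head(x[:, 0])  # classify from the class token


model = ToyViT()
images = torch.randn(2, 3, 224, 224)

# children() yields only patch_embed, blocks, norm, head; the concatenation and
# positional-embedding step inside forward_features is skipped entirely.
truncated = nn.Sequential(*list(model.children())[:-1])
print(truncated(images).shape)                # torch.Size([2, 196, 32])
print(model.forward_features(images).shape)   # torch.Size([2, 197, 32])
```

With a real timm ViT-B/16 at 224x224, the analogous call `vit_model.forward_features(x)` should return the `[B, 197, 768]` sequence including the class token.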
import torch
import torch.nn as nn
import timm


class CustomVisionTransformer(nn.Module):
    def __init__(self, model_name, num_classes):
        super(CustomVisionTransformer, self).__init__()
        # Load the pretrained ViT model
        self.vit_model = timm.create_model(model_name, pretrained=True)
        # Remove the classifier head (it is not used, since forward() calls forward_features)
        self.vit_model.head = nn.Identity()
        # Add a custom head; num_features is the token embedding size (768 for ViT-Base)
        self.custom_head = nn.Linear(self.vit_model.num_features, num_classes)

    def forward(self, x):
        # forward_features returns the full token sequence after the final LayerNorm,
        # e.g. [B, 197, 768] for ViT-B/16 at 224x224 (class token + 196 patch tokens).
        # Calling self.vit_model(x) instead would return an already-pooled [B, 768]
        # tensor, and torch.mean(x, dim=1) would then average over features, not tokens.
        x = self.vit_model.forward_features(x)
        # Take the mean of all tokens (including the class token) -> [B, 768]
        # To exclude the class token, use x[:, 1:] instead of x
        x = torch.mean(x, dim=1)
        # Pass through the custom classification head -> [B, num_classes]
        x = self.custom_head(x)
        return x
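The pooling only means "average over tokens" if it is applied to the unpooled `[B, 1 + N, D]` sequence. A minimal sketch of a standalone head that makes this explicit, assuming the class token sits at index 0 and is the only prefix token (true for a standard ViT; models with extra prefix tokens such as distillation or register tokens would need more than index 0 dropped). `MeanPoolHead` and `include_cls` are hypothetical names, and a random tensor stands in for real `forward_features` output.

```python
import torch
import torch.nn as nn


class MeanPoolHead(nn.Module):
    """Hypothetical head: average the token sequence, then classify.

    Expects the unpooled token sequence [B, 1 + N, D] with the class token first.
    """

    def __init__(self, embed_dim, num_classes, include_cls=True):
        super().__init__()
        self.include_cls = include_cls
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, tokens):
        if tokens.dim() != 3:
            raise ValueError(f"expected [B, tokens, D], got shape {tuple(tokens.shape)}")
        # Optionally drop the class token (index 0) before averaging
        if not self.include_cls:
            tokens = tokens[:, 1:]
        return self.fc(tokens.mean(dim=1))


tokens = torch.randn(4, 197, 768)  # stand-in for forward_features output
head = MeanPoolHead(768, num_classes=5, include_cls=False)
print(head(tokens).shape)          # torch.Size([4, 5])

# Stand-in for what the model returns once its head is Identity: the class token only
pooled = tokens[:, 0]              # [4, 768]
print(pooled.mean(dim=1).shape)    # torch.Size([4]): averages features, not tokens
```

The shape check in `forward` is a design choice: it turns the silent "mean over features" mistake into an immediate error.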