How do I extract features from a torchvision VisionTransformer (ViT)?


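I want to extract features from a pretrained VisionTransformer so that I can use them in a downstream task. How can I extract features with vit_b_16 from torchvision? The output should be a 768-dimensional feature vector per image.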

As I would with a CNN, I simply tried removing the output layer and passing the input through the remaining layers:

    from torch import nn

    from torchvision.models.vision_transformer import vit_b_16
    from torchvision.models import ViT_B_16_Weights
    
    from PIL import Image as PIL_Image

    vit = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
    modules = list(vit.children())[:-1]
    feature_extractor = nn.Sequential(*modules)

    preprocessing = ViT_B_16_Weights.DEFAULT.transforms()

    img = PIL_Image.open("example.png")
    img = preprocessing(img)

    feature_extractor(img)

But this raises an exception:

RuntimeError: The size of tensor a (14) must match the size of tensor b (768) at non-singleton dimension 2
pytorch computer-vision feature-extraction transformer-model torchvision