这段代码运行完美,但我想知道 my_forward 函数中的参数“x”指的是什么

问题描述 投票:0回答:1

参考 VIT 转换器示例中的注意力图:https://github.com/huggingface/pytorch-image-models/discussions/1232?sort=old

这段代码运行完美,但我想知道 my_forward 函数中的参数“x”指的是什么。以及 x 值在代码中如何以及在何处传递给函数 my_forward。

def my_forward(x):
        B, N, C = x.shape

        qkv = attn_obj.qkv(x).reshape(
            B, N, 3, attn_obj.num_heads, C // attn_obj.num_heads).permute(2, 0, 3, 1, 4)
        # make torchscript happy (cannot use tensor as tuple)
        q, k, v = qkv.unbind(0) 
pytorch pytorch-lightning attention-model self-attention vision-transformer
1个回答
0
投票

这需要一些代码检查,但如果您查看正确的位置,您可以轻松找到实现。让我们从您的片段开始。

  • my_forward_wrapper
    函数是一个定义
    my_forward
    并返回它的函数生成器。此实现覆盖了加载模型
    blocks[-1].attn
    的最后一个块注意力层
    "deit_small_distilled_patch16_224"
    的实现。

    model.blocks[-1].attn.forward = my_forward_wrapper(model.blocks[-1].attn)
    
  • x
    对应的是前一个块的输出。要理解,您可以深入了解 timm 的源代码。脚本中加载的模型是
    deit_small_distilled_patch16_224
    ,它返回一个
    VisionTransformerDistilled
    实例。这些块在
    VisionTransformer
    类中定义。有
    n=depth
    块按顺序定义。默认的块定义由
    Block
    给出,其中attn由
    Attention
    实现,详细信息如下:

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x) \
                  .reshape(B, N, 3, self.num_heads, self.head_dim) \
                  .permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)
    
        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v
    
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
    
© www.soinside.com 2019 - 2024. All rights reserved.