参考 VIT 转换器示例中的注意力图:https://github.com/huggingface/pytorch-image-models/discussions/1232?sort=old
这段代码运行完美,但我想知道 my_forward 函数中的参数“x”指的是什么。以及 x 值在代码中如何以及在何处传递给函数 my_forward。
def my_forward(x):
B, N, C = x.shape
qkv = attn_obj.qkv(x).reshape(
B, N, 3, attn_obj.num_heads, C // attn_obj.num_heads).permute(2, 0, 3, 1, 4)
# make torchscript happy (cannot use tensor as tuple)
q, k, v = qkv.unbind(0)
这需要一些代码检查,但如果您查看正确的位置,您可以轻松找到实现。让我们从您的片段开始。
my_forward_wrapper
函数是一个定义my_forward
并返回它的函数生成器。此实现覆盖了加载模型 blocks[-1].attn
的最后一个块注意力层 "deit_small_distilled_patch16_224"
的实现。
model.blocks[-1].attn.forward = my_forward_wrapper(model.blocks[-1].attn)
x
对应的是前一个块的输出。要理解,您可以深入了解 timm 的源代码。脚本中加载的模型是 deit_small_distilled_patch16_224
,它返回一个 VisionTransformerDistilled
实例。这些块在 VisionTransformer
类中定义。有 n=depth
块按顺序定义。默认的块定义由Block
给出,其中attn由Attention
实现,详细信息如下:
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x) \
.reshape(B, N, 3, self.num_heads, self.head_dim) \
.permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
if self.fused_attn:
x = F.scaled_dot_product_attention(
q, k, v,
dropout_p=self.attn_drop.p if self.training else 0.,
)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x