Runtime error with PyTorch's MultiheadAttention: how do I fix the shape mismatch?

Problem description

I am running into a problem with the input shapes for PyTorch's MultiheadAttention. I initialized the MultiheadAttention as follows:

attention = MultiheadAttention(embed_dim=1536, num_heads=4)

The input tensors have the following shapes (a standalone repro sketch follows the list):

  • query.shape is torch.Size([1, 1, 1536])
  • key.shape and value.shape are both torch.Size([1, 23, 1536])
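
A minimal standalone sketch that reproduces this setup, with dummy random tensors standing in for the real thumbnail and embedding from make_embedding.py:

import torch
from torch.nn import MultiheadAttention

attention = MultiheadAttention(embed_dim=1536, num_heads=4)  # default batch_first=False

thumbnail = torch.randn(1, 1, 1536)   # stands in for the real query
embedding = torch.randn(1, 23, 1536)  # stands in for the real key/value

# This call triggers the RuntimeError shown below
output, attn_weights = attention(thumbnail, embedding, embedding)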

However, when I try to use these inputs, I get the following error:

RuntimeError                              Traceback (most recent call last)
Cell In[15], line 1
----> 1 _ = cal_attn_weight_embedding(attention, top_j_sim_video_embeddings_list)

File ~/main/reproduct/choi/make_embedding.py:384, in cal_attn_weight_embedding(attention, top_j_sim_video_embeddings_list)
    381 print(embedding.shape)
    383 # attention
--> 384 output, attn_weights = attention(thumbnail, embedding, embedding)
    385 # attn_weight shape: (1, 1, j+1)
    387 attn_weights = attn_weights.squeeze(0).unsqueeze(-1)  # shape: (j+1, 1)

File ~/anaconda3/envs/choi_venv/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/choi_venv/lib/python3.8/site-packages/torch/nn/modules/activation.py:1205, in MultiheadAttention.forward(self, query, key, value, key_padding_mask, need_weights, attn_mask, average_attn_weights, is_causal)
   1191     attn_output, attn_output_weights = F.multi_head_attention_forward(
   1192         query, key, value, self.embed_dim, self.num_heads,
...
   5281     # TODO finish disentangling control flow so we don't do in-projections when statics are passed
   5282     assert static_k.size(0) == bsz * num_heads, \
   5283         f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"

RuntimeError: shape '[1, 4, 384]' is invalid for input of size 35328

Why am I getting this error?

My main execution environment is as follows:

  • Ubuntu 20.04
  • Anaconda 1.7.2
  • Python 3.8.5
  • VSCode 1.87.2
  • PyTorch 2.0.1

Thank you in advance.

pytorch multihead-attention
1 Answer

You need to change

attention = MultiheadAttention(embed_dim=1536, num_heads=4)

to

attention = MultiheadAttention(embed_dim=1536, num_heads=4, batch_first=True)

With the default batch_first=False, the layer reads its inputs as (seq_len, batch, embed_dim), so your query is interpreted as batch size 1 while your key/value are interpreted as batch size 23, and the computation treats this as a batch-size mismatch. Internally the key (23 × 1536 = 35328 elements) is then reshaped against the query's batch size into [src_len=1, bsz × num_heads=4, head_dim=384], which is exactly the invalid-shape error in your traceback.
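
A minimal sketch of the corrected setup, using dummy tensors with the shapes from the question (your real code would pass thumbnail and embedding instead):

import torch
from torch.nn import MultiheadAttention

# batch_first=True makes the layer read inputs as (batch, seq_len, embed_dim)
attention = MultiheadAttention(embed_dim=1536, num_heads=4, batch_first=True)

query = torch.randn(1, 1, 1536)         # (batch=1, tgt_len=1,  embed_dim=1536)
key = value = torch.randn(1, 23, 1536)  # (batch=1, src_len=23, embed_dim=1536)

output, attn_weights = attention(query, key, value)
print(output.shape)        # torch.Size([1, 1, 1536])
print(attn_weights.shape)  # torch.Size([1, 1, 23])

Alternatively, you could keep the default batch_first=False and reshape your inputs to (seq_len, batch, embed_dim), i.e. a query of shape (1, 1, 1536) and key/value of shape (23, 1, 1536).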
