I'm running into a problem with the input shapes for PyTorch's MultiheadAttention. I initialized the module as follows:
attention = MultiheadAttention(embed_dim=1536, num_heads=4)
The input tensors have the following shapes:
However, when I try to use these inputs, I get the following error:
RuntimeError Traceback (most recent call last)
Cell In[15], line 1
----> 1 _ = cal_attn_weight_embedding(attention, top_j_sim_video_embeddings_list)
File ~/main/reproduct/choi/make_embedding.py:384, in cal_attn_weight_embedding(attention, top_j_sim_video_embeddings_list)
381 print(embedding.shape)
383 # attention
--> 384 output, attn_weights = attention(thumbnail, embedding, embedding)
385 # attn_weight shape: (1, 1, j+1)
387 attn_weights = attn_weights.squeeze(0).unsqueeze(-1) # shape: (j+1, 1)
File ~/anaconda3/envs/choi_venv/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/anaconda3/envs/choi_venv/lib/python3.8/site-packages/torch/nn/modules/activation.py:1205, in MultiheadAttention.forward(self, query, key, value, key_padding_mask, need_weights, attn_mask, average_attn_weights, is_causal)
1191 attn_output, attn_output_weights = F.multi_head_attention_forward(
1192 query, key, value, self.embed_dim, self.num_heads,
...
5281 # TODO finish disentangling control flow so we don't do in-projections when statics are passed
5282 assert static_k.size(0) == bsz * num_heads, \
5283 f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
RuntimeError: shape '[1, 4, 384]' is invalid for input of size 35328
Why am I getting this error?
The main execution environment is as follows:
Thanks in advance.
You need to change
attention = MultiheadAttention(embed_dim=1536, num_heads=4)
to
attention = MultiheadAttention(embed_dim=1536, num_heads=4, batch_first=True)
With the default batch_first=False, MultiheadAttention expects inputs of shape (seq_len, batch, embed_dim), so your batch dimension is interpreted as the sequence dimension. As a result, the computation concludes that your query's batch size does not match your key/value batch size, which triggers the invalid-reshape error you are seeing.
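As a sketch of the fix: the error message's 35328 elements divide as 23 × 1536, which is consistent with a key of shape (1, 23, 1536) being misread as batch size 23 under the default layout. The exact sequence length in your data may differ; 23 here is only inferred from the error, not confirmed by your code.

```python
import torch
from torch.nn import MultiheadAttention

# With batch_first=True, inputs are (batch, seq_len, embed_dim), matching
# the layout the question's tensors appear to use.
attention = MultiheadAttention(embed_dim=1536, num_heads=4, batch_first=True)

# Hypothetical shapes inferred from the error: a single query token attending
# over 23 key/value tokens (35328 / 1536 = 23).
query = torch.randn(1, 1, 1536)   # (batch=1, seq_len=1, embed_dim=1536)
key = torch.randn(1, 23, 1536)    # (batch=1, seq_len=23, embed_dim=1536)

output, attn_weights = attention(query, key, key)
print(output.shape)        # torch.Size([1, 1, 1536])
print(attn_weights.shape)  # torch.Size([1, 1, 23])
```

With batch_first=False (the default), the same call would read the key's dim 1 (23) as the batch size while the query's batch size is 1, and the internal reshape to [1, num_heads=4, head_dim=384] fails because the key holds 1 × 23 × 1536 = 35328 elements.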