How do I replace this naive code with scaled_dot_product_attention() in PyTorch?

Problem description

Consider this code snippet from Crossformer:

def forward(self, queries, keys, values):
    B, L, H, E = queries.shape
    _, S, _, D = values.shape
    scale = self.scale or 1./sqrt(E)

    scores = torch.einsum("blhe,bshe->bhls", queries, keys)
    A = self.dropout(torch.softmax(scale * scores, dim=-1))
    V = torch.einsum("bhls,bshd->blhd", A, values)
    
    return V.contiguous()

I'm trying to speed it up by replacing the naive implementation with Flash Attention via scaled_dot_product_attention. To that end, I did the following:

def forward(self, queries, keys, values):
    # I'm not sure about the below - it's just a ChatGPT-assisted guess
    # B represents the batch size.
    # L is the sequence length for queries (or target sequence length).
    # H is the number of attention heads.
    # E is the depth (dimension) of each attention head for queries/keys.
    # S is the sequence length for keys/values (or source sequence length).
    # D is the depth (dimension) of each attention head for values.
    B, L, H, E = queries.shape
    _, S, _, D = values.shape

    y = torch.nn.functional.scaled_dot_product_attention(
        queries, keys, values,
        dropout_p=self.dropout_p if self.training else 0.0)  # dropout_p must be a float
    y = y.contiguous()
    return y

However, with the code above I get the following error:

RuntimeError: The size of tensor a (10) must match the size of tensor b (4) at non-singleton dimension 1

The debugger shows the following tensor sizes:

  • keys: (2048, 4, 16, 32)
  • queries: (2048, 10, 16, 32)
  • values: (2048, 4, 16, 32)

What am I missing in this change?

python deep-learning pytorch tensor attention-model
1 Answer

The sequence dimension must be at dimension -2 (see the documentation): scaled_dot_product_attention expects the query as (..., L, E) and the key/value as (..., S, E), i.e. (B, H, L, E) here rather than (B, L, H, E).

So in your case you have to transpose dimension 1 with dimension 2:

forward(
   queries.transpose(1, 2),
   keys.transpose(1, 2),
   values.transpose(1, 2)
).transpose(1, 2)
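
Equivalently, the transposes can live inside forward itself. Below is a minimal sketch of how the replacement could look, assuming self.dropout_p holds the dropout probability and self.scale the optional scale factor from the original module (the scale keyword requires PyTorch 2.1+; passing None gives the default 1/sqrt(E), matching the original):

import torch.nn.functional as F

def forward(self, queries, keys, values):
    # Inputs arrive as (B, L, H, E) / (B, S, H, D), as in the original einsum version.
    B, L, H, E = queries.shape
    _, S, _, D = values.shape

    # scaled_dot_product_attention expects (B, H, L, E): sequence length at dim -2.
    q = queries.transpose(1, 2)   # (B, H, L, E)
    k = keys.transpose(1, 2)      # (B, H, S, E)
    v = values.transpose(1, 2)    # (B, H, S, D)

    y = F.scaled_dot_product_attention(
        q, k, v,
        dropout_p=self.dropout_p if self.training else 0.0,
        scale=self.scale,  # None -> default 1/sqrt(E), same as the einsum version
    )

    # Back to (B, L, H, D) to match the original return layout.
    return y.transpose(1, 2).contiguous()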