Consider this code snippet from Crossformer:
def forward(self, queries, keys, values):
    B, L, H, E = queries.shape
    _, S, _, D = values.shape
    scale = self.scale or 1. / sqrt(E)

    scores = torch.einsum("blhe,bshe->bhls", queries, keys)
    A = self.dropout(torch.softmax(scale * scores, dim=-1))
    V = torch.einsum("bhls,bshd->blhd", A, values)
    return V.contiguous()
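For context, the einsum treats the inputs as (B, L, H, E), with heads on dimension 2. As far as I understand, the score computation is equivalent to this matmul form (my own sketch operating on the same queries/keys, not from the Crossformer source):

# Equivalent to torch.einsum("blhe,bshe->bhls", queries, keys):
# bring heads to dim 1, then batch-matmul over the last two dims.
scores = torch.matmul(
    queries.permute(0, 2, 1, 3),  # (B, H, L, E)
    keys.permute(0, 2, 3, 1),     # (B, H, E, S)
)                                 # -> (B, H, L, S)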
I'm trying to speed it up by replacing the plain implementation with Flash Attention. To do that, I made the following change:
def forward(self, queries, keys, values):
    # I'm not sure about the below - it's just a ChatGPT-assisted guess:
    # B is the batch size.
    # L is the sequence length for queries (the target sequence length).
    # H is the number of attention heads.
    # E is the depth (dimension) of each attention head for queries/keys.
    # S is the sequence length for keys/values (the source sequence length).
    # D is the depth (dimension) of each attention head for values.
    B, L, H, E = queries.shape
    _, S, _, D = values.shape
    y = torch.nn.functional.scaled_dot_product_attention(
        queries, keys, values,
        dropout_p=self.dropout_p if self.training else 0.0)  # dropout_p must be a float, not None
    return y.contiguous()
However, with the code above I get the following error:
RuntimeError: The size of tensor a (10) must match the size of tensor b (4) at non-singleton dimension 1
The debugger shows the following tensor sizes:

keys:    (2048, 4, 16, 32)
queries: (2048, 10, 16, 32)
values:  (2048, 4, 16, 32)

What am I missing in this change?
The sequence dimension must sit at dimension -2 (see the scaled_dot_product_attention documentation): the inputs must be laid out as (B, H, L, E) with heads at dimension 1, whereas your tensors are (B, L, H, E), so the function mismatches your query length (10) against your heads-at-dim-1 key length (4).
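A quick dummy-tensor check makes the expected layout concrete (a minimal sketch using your sizes):

import torch
import torch.nn.functional as F

B, H, L, S, E = 2048, 16, 10, 4, 32
q = torch.randn(B, H, L, E)  # heads at dim 1, sequence length at dim -2
k = torch.randn(B, H, S, E)
v = torch.randn(B, H, S, E)
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2048, 16, 10, 32])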
So in your case you have to transpose dimension 1 with dimension 2 on the way in, and transpose the result back:
forward(
    queries.transpose(1, 2),
    keys.transpose(1, 2),
    values.transpose(1, 2),
).transpose(1, 2)
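Putting it together, here is a minimal sketch of the corrected forward (assuming, as in your attempt, that self.dropout_p holds the dropout probability):

import torch
import torch.nn.functional as F

def forward(self, queries, keys, values):
    # Inputs arrive as (B, L, H, E); scaled_dot_product_attention expects
    # heads at dim 1 and the sequence dimension at -2, i.e. (B, H, L, E).
    y = F.scaled_dot_product_attention(
        queries.transpose(1, 2),  # (B, H, L, E)
        keys.transpose(1, 2),     # (B, H, S, E)
        values.transpose(1, 2),   # (B, H, S, D)
        dropout_p=self.dropout_p if self.training else 0.0,
    )
    # Transpose back so callers still see the original (B, L, H, D) layout.
    return y.transpose(1, 2).contiguous()

Note that scaled_dot_product_attention already scales the scores by 1/sqrt(E) by default, which matches the original snippet's fallback; if self.scale can be something non-default, newer PyTorch versions accept an explicit scale= argument.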