使用 self_attn.in_proj_weight 从 PyTorch 获取查询、键和值矩阵

问题描述 投票:0回答:2

我们已经根据教程实现了一个变压器这里

我们需要访问查询、键和值矩阵的权重,并计划使用

model.state_dict()
来完成此操作。然而,模型将这些矩阵存储为这个共享矩阵中的串联。

model.state_dict()['transformer_encoder.layers.0.self_attn.in_proj_weight']

我们假设它们按照查询、键、值的顺序连接起来。如果是这样,我们可以手动分割张量。但是,我们无法在 PyTorch 文档中验证这是否是实际的顺序。有没有简单的方法来验证是否是这种情况?或者任何其他方式来单独获取此变压器模型的查询、键和值矩阵?

pytorch tensor transformer-model
2个回答
0
投票

Pytorch 代码库中 MultiHeadAttention 的实现遵循简单的 check:

if not self._qkv_same_embed_dim:
            self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs), requires_grad = not self.freeze_proj_mat['q'])
            self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs), requires_grad = not self.freeze_proj_mat['k'])
            self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs), requires_grad = not self.freeze_proj_mat['v'])
            self.register_parameter('in_proj_weight', None)
else:
            self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

哪里,

self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

这里,

kdim, embed_dim, vdim
根据函数定义都有其通常的含义,请检查这里

这是从用户那里抽象出来的实现细节。但正如您所提到的,要在

Q, K, V
self._qkv_same_embed_dim
时访问
True
矩阵,您可以提取此张量并调用
_in_projection_packed
API 源
中提供的方法 nn.functional

您可以查看所有提供的这些功能实现的链接以供参考。


TLDR

您可以使用

torch.split
函数将投影权重拆分为查询矩阵、键矩阵和值矩阵。像这样,

in_proj_weight = model.state_dict()['transformer_encoder.layers.0.self_attn.in_proj_weight']
q, k, v = torch.split(in_proj_weight, [embed_dim, embed_dim, embed_dim])

希望这有帮助。


0
投票

是的,按照其他答案所述进行拆分,但您还需要转置生成的矩阵。换句话说,这就是您手动重现 MultiheadAttention 计算的方式:

# create a tensor to test on
l = torch.LongTensor([[0, 1, 2, 3], [4, 5, 6, 7]])
emb = torch.nn.Embedding(16, 24)
t = emb(l)

# create the attention with 3 heads of 8 dimensions each
attn = torch.nn.MultiheadAttention(3*8, 3, 0.0, batch_first=True)
o, a = attn(t, t, t, need_weights=True, average_attn_weights=False)

# manually calculate the attention to compare
q_w, k_w, v_w = torch.split(attn.in_proj_weight, [24, 24, 24])
q_b, k_b, v_b = torch.split(attn.in_proj_bias, [24, 24, 24])
q_w = q_w.T
k_w = k_w.T
v_w = v_w.T
qp = t @ q_w + q_b
kp = t @ k_w + k_b
vp = t @ v_w + v_b
qpp = qp.reshape(2, 4, 3, 8).transpose(1, 2)
kpp = kp.reshape(2, 4, 3, 8).transpose(1, 2)
vpp = vp.reshape(2, 4, 3, 8).transpose(1, 2)
ap = torch.nn.functional.softmax(qpp @ kpp.transpose(-1, -2) / math.sqrt(8), dim=-1)
pp = ap @ vpp
pp = pp.transpose(1, 2).reshape(2, 4, 24)
op = attn.out_proj(pp)

assert torch.allclose(o, op)
assert torch.allclose(a, ap)
print("equal!")
© www.soinside.com 2019 - 2024. All rights reserved.