我们已经根据教程实现了一个变压器这里。
我们需要访问查询、键和值矩阵的权重,并计划使用
model.state_dict()
来完成此操作。然而,模型将这些矩阵存储为这个共享矩阵中的串联。
model.state_dict()['transformer_encoder.layers.0.self_attn.in_proj_weight']
我们假设它们按照查询、键、值的顺序连接起来。如果是这样,我们可以手动分割张量。但是,我们无法在 PyTorch 文档中验证这是否是实际的顺序。有没有简单的方法来验证是否是这种情况?或者任何其他方式来单独获取此变压器模型的查询、键和值矩阵?
Pytorch 代码库中 MultiHeadAttention 的实现遵循简单的 check:
if not self._qkv_same_embed_dim:
self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs), requires_grad = not self.freeze_proj_mat['q'])
self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs), requires_grad = not self.freeze_proj_mat['k'])
self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs), requires_grad = not self.freeze_proj_mat['v'])
self.register_parameter('in_proj_weight', None)
else:
self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
self.register_parameter('q_proj_weight', None)
self.register_parameter('k_proj_weight', None)
self.register_parameter('v_proj_weight', None)
哪里,
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
这里,
kdim, embed_dim, vdim
根据函数定义都有其通常的含义,请检查这里。
这是从用户那里抽象出来的实现细节。但正如您所提到的,要在
Q, K, V
为 self._qkv_same_embed_dim
时访问 True
矩阵,您可以提取此张量并调用 _in_projection_packed
API 源中提供的方法
nn.functional
。
您可以查看所有提供的这些功能实现的链接以供参考。
torch.split
函数将投影权重拆分为查询矩阵、键矩阵和值矩阵。像这样,
in_proj_weight = model.state_dict()['transformer_encoder.layers.0.self_attn.in_proj_weight']
q, k, v = torch.split(in_proj_weight, [embed_dim, embed_dim, embed_dim])
希望这有帮助。
是的,按照其他答案所述进行拆分,但您还需要转置生成的矩阵。换句话说,这就是您手动重现 MultiheadAttention 计算的方式:
# create a tensor to test on
l = torch.LongTensor([[0, 1, 2, 3], [4, 5, 6, 7]])
emb = torch.nn.Embedding(16, 24)
t = emb(l)
# create the attention with 3 heads of 8 dimensions each
attn = torch.nn.MultiheadAttention(3*8, 3, 0.0, batch_first=True)
o, a = attn(t, t, t, need_weights=True, average_attn_weights=False)
# manually calculate the attention to compare
q_w, k_w, v_w = torch.split(attn.in_proj_weight, [24, 24, 24])
q_b, k_b, v_b = torch.split(attn.in_proj_bias, [24, 24, 24])
q_w = q_w.T
k_w = k_w.T
v_w = v_w.T
qp = t @ q_w + q_b
kp = t @ k_w + k_b
vp = t @ v_w + v_b
qpp = qp.reshape(2, 4, 3, 8).transpose(1, 2)
kpp = kp.reshape(2, 4, 3, 8).transpose(1, 2)
vpp = vp.reshape(2, 4, 3, 8).transpose(1, 2)
ap = torch.nn.functional.softmax(qpp @ kpp.transpose(-1, -2) / math.sqrt(8), dim=-1)
pp = ap @ vpp
pp = pp.transpose(1, 2).reshape(2, 4, 24)
op = attn.out_proj(pp)
assert torch.allclose(o, op)
assert torch.allclose(a, ap)
print("equal!")