我有以下模型,它构成了我的整个模型管道中的步骤之一:
import torch
import torch.nn as nn
class NPB(nn.Module):
def __init__(self, d, nhead, num_layers, dropout=0.1):
super(NPB, self).__init__()
self.te = nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=d, nhead=nhead, dropout=dropout, batch_first=True),
num_layers=num_layers,
)
self.t_emb = nn.Parameter(torch.randn(1, d))
self.L = nn.Parameter(torch.randn(1, d))
self.td = nn.TransformerDecoder(
nn.TransformerDecoderLayer(d_model=d, nhead=nhead, dropout=dropout, batch_first=True),
num_layers=num_layers,
)
self.ffn = nn.Linear(d, 6)
def forward(self, t_v, t_i):
print("--------------- t_v, t_i -----------------")
print('t_v: ', tuple(t_v.shape))
print('t_i: ', tuple(t_i.shape))
print("--------------- t_v + t_i + t_emb -----------------")
_x = t_v + t_i + self.t_emb
print(tuple(_x.shape))
print("--------------- te ---------------")
_x = self.te(_x)
print(tuple(_x.shape))
print("--------------- td ---------------")
_x = self.td(self.L, _x)
print(tuple(_x.shape))
print("--------------- ffn ---------------")
_x = self.ffn(_x)
print(tuple(_x.shape))
return _x
这里
t_v
和 t_i
是来自早期编码器块的输入。我将它们作为 (4,256)
的形状传递,其中 256
是特征数量,4
是批量大小。 t_emb
是时间嵌入。 L
表示学习矩阵,表示查询的嵌入。我用以下代码测试了这个模块块:
t_v = torch.randn((4,256))
t_i = torch.randn((4,256))
npb = NPB(d=256, nhead=8, num_layers=2)
npb(t_v, t_i)
输出:
=============== NPB ===============
--------------- t_v, t_i -----------------
t_v: (4, 256)
t_i: (4, 256)
--------------- t_v + t_i + t_emb -----------------
(4, 256)
--------------- te ---------------
(4, 256)
--------------- td ---------------
(1, 256)
--------------- ffn ---------------
(1, 6)
我期望输出的形状应为
(4,6)
,批量大小为 6
的每个样本有 6 个值。但输出的大小是(1,6)
。经过大量调整后,我尝试将 t_emb
和 L
形状从 (1,d)
更改为 (4,d)
,因为我不希望所有采样共享这些变量(通过广播:
self.t_emb = nn.Parameter(torch.randn(4, d)) # [n, d] = [4, 256]
self.L = nn.Parameter(torch.randn(4, d))
这给出了所需的形状输出 (4,6:
--------------- t_v, t_i -----------------
t_v: (4, 256)
t_i: (4, 256)
--------------- t_v + t_i + t_emb -----------------
(4, 256)
--------------- te ---------------
(4, 256)
--------------- td ---------------
(4, 256)
--------------- ffn ---------------
(4, 6)
我有以下疑问:
Q1. 到底为什么将
L
和 t_emb
形状从 (1,d)
更改为 (4,d)
有效?为什么不能通过广播与(1,d)
合作?查看文档 - transformer 类、transformer 解码器
对于未批处理(2 暗淡)输入,其中
src = (S, E)
和 tgt = (T, E)
,输出将具有形状 (T, E)
。
在 Transformer 解码器层中,第一个参数是
tgt
,它定义输出大小。
由于您将
tgt
参数 L
定义为 torch.randn(1, d)
,因此您的变压器解码器输出的大小将为 (1, d)
。
这与广播无关,这只是变压器层的输入/输出机制。