Transformer 语言模型的简单实现,例如这个,定义了 3 个矩阵 K、Q、V 来计算键、查询和值。然而,矩阵 K 和 Q 永远不会单独使用:所有 Transformer 计算都会形成它们的乘积
Q^t K
。所以我想知道为什么不直接学习这个乘积矩阵而不是把它分成2个矩阵K和Q。
部分答案可能来自 K 和 Q 的大小,即
d -> n
,其中 d 是 token 嵌入的维度,n 是键和查询的维度。 Q^t K
的大小为d -> d
。因此,分别学习 K 和 Q 意味着优化 2*n*d
参数,而学习乘积 Q^t K
就是 d*d
参数。我看到的唯一有用的分割是当 n <= d/2
时,因为需要优化的参数较少。但在极限情况n = d/2
下,乘积矩阵Q^t K
的秩是d/2
,这是非常退化的。使用相同数量的参数d^2
,我们可以学习一个无约束的方阵。这可能会在训练数据中学习更灵活和微妙的模式。
在 Attention is all you need 论文基础模型第 9 页中,我们看到 d = 512 和 n = 64,因此乘积矩阵
Q^t K
确实具有退化秩。减少参数数量是这里的真实且独特的意图吗?这些退化等级有助于自然语言处理是否有理论上的依据?
2 * n * d 与 d * d 结构与 LORA 的工作方式非常相似,或者实际上是 WALS 模型。所以只要 2nd < d*d, this might be a good way of saving parameters.