Transformer 语言模型中的关键矩阵冗余?

问题描述 投票:0回答:1

Transformer 语言模型的简单实现,例如这个,定义了 3 个矩阵 K、Q、V 来计算键、查询和值。然而,矩阵 K 和 Q 永远不会单独使用:所有 Transformer 计算都会形成它们的乘积

Q^t K
。所以我想知道为什么不直接学习这个乘积矩阵而不是把它分成2个矩阵K和Q。

部分答案可能来自 K 和 Q 的大小,即

d -> n
,其中 d 是 token 嵌入的维度,n 是键和查询的维度。
Q^t K
的大小为
d -> d
。因此,分别学习 K 和 Q 意味着优化
2*n*d
参数,而学习乘积
Q^t K
就是
d*d
参数。我看到的唯一有用的分割是当
n <= d/2
时,因为需要优化的参数较少。但在极限情况
n = d/2
下,乘积矩阵
Q^t K
的秩是
d/2
,这是非常退化的。使用相同数量的参数
d^2
,我们可以学习一个无约束的方阵。这可能会在训练数据中学习更灵活和微妙的模式。

Attention is all you need 论文基础模型第 9 页中,我们看到 d = 512 和 n = 64,因此乘积矩阵

Q^t K
确实具有退化秩。减少参数数量是这里的真实且独特的意图吗?这些退化等级有助于自然语言处理是否有理论上的依据?

nlp transformer-model
1个回答
0
投票

2 * n * d 与 d * d 结构与 LORA 的工作方式非常相似,或者实际上是 WALS 模型。所以只要 2nd < d*d, this might be a good way of saving parameters.

© www.soinside.com 2019 - 2024. All rights reserved.