我之所以做到这一点,是因为必须对变压器模型的长输入进行批处理,并注意到批处理和非批处理结果之间的差异。我终于隔离了我注意到的第一个差异,结果如下:
import torch
n = 20
vec = torch.rand(n, 20)
a = torch.rand(30, 20)
for i in range(1, n+1):
print(i, torch.equal(
torch.nn.functional.linear(vec, a)[:i],
torch.nn.functional.linear(vec[:i], a)))
产生输出:
1 False
2 False
3 False
4 True
5 True
6 True
7 False
8 False
9 False
10 True
11 True
12 True
13 False
14 False
15 False
16 True
17 True
18 True
19 True
20 True
这只是一个操作,当多次组合时(如在变压器中),可能会导致较大的分歧,扩大 torch.allclose 输出 True 的 atol。为什么会这样,有人能做点什么吗?
欢迎来到浮点运算的勇敢世界!
float
运算会引入舍入误差,而矩阵乘法会将它们累积到显着值。
https://pytorch.org/docs/stable/notes/numerical_accuracy.html
如果您避免不精确的舍入
vec = torch.floor (torch.rand(n, 20)*10)
a = torch.floor( torch.rand(30, 20)*10 )
你会得到所有
True
-s。
可能的解决方案是使用
torch.isclose
。