PyTorch 矩阵乘法不考虑切片

问题描述 投票:0回答:1

我之所以做到这一点,是因为必须对变压器模型的长输入进行批处理,并注意到批处理和非批处理结果之间的差异。我终于隔离了我注意到的第一个差异,结果如下:

import torch

n = 20

vec = torch.rand(n, 20)
a = torch.rand(30, 20)

for i in range(1, n+1):
    print(i, torch.equal(
        torch.nn.functional.linear(vec, a)[:i],
        torch.nn.functional.linear(vec[:i], a)))

产生输出:

1 False
2 False
3 False
4 True
5 True
6 True
7 False
8 False
9 False
10 True
11 True
12 True
13 False
14 False
15 False
16 True
17 True
18 True
19 True
20 True

这只是一个操作,当多次组合时(如在变压器中),可能会导致较大的分歧,扩大 torch.allclose 输出 True 的 atol。为什么会这样,有人能做点什么吗?

python pytorch precision huggingface-transformers torch
1个回答
0
投票

欢迎来到浮点运算的勇敢世界!

float
运算会引入舍入误差,而矩阵乘法会将它们累积到显着值。 https://pytorch.org/docs/stable/notes/numerical_accuracy.html 如果您避免不精确的舍入

vec = torch.floor (torch.rand(n, 20)*10)
a = torch.floor( torch.rand(30, 20)*10 )

你会得到所有

True
-s。

可能的解决方案是使用

torch.isclose

© www.soinside.com 2019 - 2024. All rights reserved.