Pytorch 操作检测 NaNs

问题描述 投票:0回答:5

是否有 Pytorch 内部程序来检测张量中的

NaN
? Tensorflow 有
tf.is_nan
tf.check_numerics
操作...Pytorch 是否有类似的东西?我在文档中找不到类似的东西......

我正在专门寻找 Pytorch 内部例程,因为我希望这种情况发生在 GPU 和 CPU 上。这不包括基于 numpy 的解决方案(如

np.isnan(sometensor.numpy()).any()
)...

python pytorch
5个回答
87
投票

您始终可以利用以下事实:

nan != nan

>>> x = torch.tensor([1, 2, np.nan])
tensor([  1.,   2., nan.])
>>> x != x
tensor([ 0,  0,  1], dtype=torch.uint8)

使用 pytorch 0.4 还有

torch.isnan
:

>>> torch.isnan(x)
tensor([ 0,  0,  1], dtype=torch.uint8)

44
投票

从 PyTorch 0.4.1 开始,有

detect_anomaly
上下文管理器,它会在反向传播的所有步骤之间自动插入相当于
assert not torch.isnan(grad).any()
的断言。当向后传递过程中出现问题时,它非常有用。


32
投票

正如 @cleros 在 @nemo 答案的评论中所建议的,您可以使用

any()
运算符将其作为布尔值获取:

from torch import Tensor, nan
torch.isnan(Tensor([nan])).any()
# tensor(True)

这适用于

torch.nan
np.nan

请注意,

torch.isnan
要求输入是张量。这会抛出一个
TypeError
torch.isnan(torch.nan)


8
投票

如果你想直接在张量上调用它:

import torch

x = torch.randn(5, 4)
print(x.isnan().any())

输出:

import torch
x = torch.randn(5, 4)
print(x.isnan().any())
tensor(False)

4
投票

如果有任何值为 nan,则为真:

torch.any(tensor.isnan())

如果全部为 nan,则为真:

torch.all(tensor.isnan())
© www.soinside.com 2019 - 2024. All rights reserved.