如何打字检查pytorch损失?

问题描述 投票:0回答:2

我正在尝试输入检查给定参数是默认火炬损失(标准)的事实并在 python 中进行优化,因此我确信我可以计算以下操作:

loss = criterion(y_pred, y_batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()

最终,我只想检查一下:

assert isinstance(criterion, torch.loss)
assert isinstance(optimizer, torch.optimizer)

但我在火炬文档中找不到做到这一点的方法。

有什么想法吗?

python pytorch
2个回答
0
投票

损失来自

torch.nn.modules.loss
,它不是一个类。但我们可以检查损失的“路径”是否包含它,例如这样:

criterion = nn.NLLLoss()

def is_torch_loss(criterion) -> bool:
    type_ = str(type(criterion)).split("'")[1]
    parent = type_.rsplit(".", 1)[0]
    
    return parent == "torch.nn.modules.loss"


is_loss = is_torch_loss(criterion)

0
投票

您可以访问nn

import torch
import torch.nn as nn

criterion = nn.NLLLoss()

if isinstance(criterion, nn.modules.loss._Loss):
    print("The criterion is a PyTorch loss function.")
else:
    print("The criterion is not a PyTorch loss function.")
>>> criterion = nn.NLLLoss()
>>> isinstance(criterion, nn.modules.loss._Loss)
True
© www.soinside.com 2019 - 2024. All rights reserved.