我正在尝试输入检查给定参数是默认火炬损失(标准)的事实并在 python 中进行优化,因此我确信我可以计算以下操作:
loss = criterion(y_pred, y_batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()
最终,我只想检查一下:
assert isinstance(criterion, torch.loss)
assert isinstance(optimizer, torch.optimizer)
但我在火炬文档中找不到做到这一点的方法。
有什么想法吗?
损失来自
torch.nn.modules.loss
,它不是一个类。但我们可以检查损失的“路径”是否包含它,例如这样:
criterion = nn.NLLLoss()
def is_torch_loss(criterion) -> bool:
type_ = str(type(criterion)).split("'")[1]
parent = type_.rsplit(".", 1)[0]
return parent == "torch.nn.modules.loss"
is_loss = is_torch_loss(criterion)
您可以访问nn
import torch
import torch.nn as nn
criterion = nn.NLLLoss()
if isinstance(criterion, nn.modules.loss._Loss):
print("The criterion is a PyTorch loss function.")
else:
print("The criterion is not a PyTorch loss function.")
>>> criterion = nn.NLLLoss()
>>> isinstance(criterion, nn.modules.loss._Loss)
True