根据PyTorch文档,类
BCEWithLogitsLoss()
的优点是可以使用
用于数值稳定性的对数和指数技巧。
如果我们使用类
BCEWithLogitsLoss()
并将参数 reduction
设置为 None
,他们有一个公式:
我现在简化了条款,经过几行计算得到:
我很好奇这是否是源代码的做法,但我找不到它..他们唯一的代码是这样的:
nn.BCEWithLogitsLoss
实际上只是 sigmoid 函数内部的交叉熵损失。如果模型的输出层未使用 sigmoid 封装,则可以使用它。通常与单个输出层神经元的原始输出一起使用。
简单地说,你的模型的输出表示
pred
将是一个原始值。为了获得概率,您必须使用torch.sigmoid(pred)
。 (要获得实际的类标签,您需要 torch.round(torch.sigmoid(pred))
。)但是,当您使用 nn.BCEWithLogitsLoss
时,您不需要执行类似的操作(即采用 sigmoid)。在这里你只需要执行以下操作-
criterion = nn.BCEWithLogitsLoss()
loss = criterion(pred, target) # pred is just raw nn output
因此,进入实现部分,criterion 接受两个 torch 张量 - 一个是原始 nn 输出,另一个是真实的类标签,然后使用 sigmoid 包装第一个张量 - 对于张量中的每个元素,然后计算交叉熵损失
(-(target*log(sigmoid(pred)))
对于每一对并将其简化为平均值。
pytorch的所有功能代码都是用C++实现的。实现的源代码位于here。
pytorch 实现将
BCEWithLogitsLoss
计算为
其中
t_n
就是 -relu(x)
。这里使用 t_n
基本上是一种避免取正值指数的巧妙方法(从而避免溢出)。通过将 t_n
替换为 l_n
可以使这一点变得更清楚,从而产生以下等效表达式
这是一篇用Python实现的Medium文章:
https://medium.com/@sahilcarterr/why-nn-bcewithlogitsloss-numerically-stable-6a04f3052967