PyTorch中的weighted_cross_entropy_with_logits的模拟

问题描述 投票:0回答:1

我正在尝试用PyTorch训练模型。是否有任何简单的方法可以从Tensorflow创建像weighted_cross_entropy_with_logits这样的损失?

pos_weight中有weighted_cross_entropy_with_logits论证可以帮助平衡。但是BCEWithLogitsLoss中的参数列表中只有标签的权重。

tensorflow machine-learning pytorch
1个回答
2
投票

您可以根据需要编写自己的自定义丢失功能。例如,你可以写:

def weighted_cross_entropy_with_logits(logits, target, pos_weight):
    return targets * -logits.sigmoid().log() * pos_weight + 
               (1 - targets) * -(1 - logits.sigmoid()).log()

这是一个基本的实现。您应该按照here提到的步骤来确保稳定性并避免溢出。只需使用他们衍生的最终配方。

© www.soinside.com 2019 - 2024. All rights reserved.