我可以在 Pytorch 中使用 one-hot 编码输出进行分割,具有焦点和骰子损失吗?

问题描述 投票:0回答:1

我知道对于使用神经网络和交叉熵损失的分类,我们需要单热编码输出,但在 PyTorch 中,交叉熵损失不接受单热编码目标,我们应该直接在正常情况下给它标签格式。

现在,我想知道这对于图像分割任务是否相同,其中损失函数是骰子损失或焦点损失等。也就是说,如果我像 tensorflow 一样对目标掩码进行单热编码以进行分割是否可以,或者我不能做类似于 Pytorch 中的分类任务吗? (假设我将使用标准的 CNN,例如 3DUNet)

补充说明: 我已经使用神经网络对 Pytorch 进行了分类,交叉熵损失不能接受单热编码输入是没有意义的。现在,我希望对于用于分割的其他损失,这应该与交叉熵损失相同,但我不确定,因为它没有任何意义。

machine-learning deep-learning pytorch computer-vision image-segmentation
1个回答
0
投票

对于分类,

CrossEntropyLoss()
可以处理 one-hot 目标,如 doc

中所指定

这个标准期望的目标应该包含: [0,C) 范围内的类索引,其中 C 是类的数量;或每个班级的概率。

因此,对于您的 one-hot 目标,您有两个选择:

  1. 使用
    torch.argmax()
    将单热目标转换为类索引。
  2. 直接用作目标。

例如:

cross_entropy = nn.CrossEntropyLoss()
input = torch.randn(3, 5)
print(f"input: {input}")
# input: tensor([[-0.6498, -0.4508,  1.0618, -1.4337, -1.6479],
#             [-0.9778,  0.0141, -0.5646, -1.0664,  0.9022],
#             [ 0.0797,  0.7878,  0.6092, -0.2396, -0.5839]])

random_idx = torch.randint(0, 5, (3,))
print(f"random_idx: {random_idx}")
# random_idx: tensor([2, 3, 3])

target_one_hot = torch.eye(5)[random_idx]
output_one_hot = cross_entropy(input, target_one_hot)
print(f"target_one_hot: {target_one_hot}")
print(f"output_one_hot: {output_one_hot}")
# target_one_hot: tensor([[0., 0., 1., 0., 0.],
#                         [0., 0., 0., 1., 0.],
#                         [0., 0., 0., 1., 0.]])
# output_one_hot: 1.7241638898849487

target = target_one_hot.argmax(dim=1)
output = cross_entropy(input, target)
print(f"target: {target}")
print(f"output: {output}")
# target: tensor([2, 3, 3])
# output: 1.7241638898849487

对于分割,正如我上面所解释的,可以使用

CrossEntropyLoss()
来处理one-hot targets。对于
Dice
损失,由于 PyTorch 中没有官方实现,您需要自己实现它,因此您可以根据需要定义目标(one-hot 或其他)。

也许你可以查看这个 GitHub 代码,它实现了经典的 U-Net 并使用

CrossEntropyLoss
DiceLoss

训练模型
© www.soinside.com 2019 - 2024. All rights reserved.