Pytorch:在unet体系结构中使用自定义权重图的正确方法

问题描述 投票:3回答:1

[u-net体系结构中有一个著名的技巧,可以使用自定义权重图来提高准确性。下面是它的详细信息-

enter image description here

现在,通过在这里和其他多个地方询问,我了解了两种方法。我想知道哪种方法正确,或者还有其他正确的方法更正确吗?

1)首先是在训练循环中使用torch.nn.Functional方法-

loss = torch.nn.functional.cross_entropy(output, target, w)w将是计算得出的自定义重量。

2)其次是在训练循环之外的损失函数调用中使用reduction='none'criterion = torch.nn.CrossEntropy(reduction='none')

然后在训练循环中乘以自定义权重-

gt # Ground truth, format torch.long
pd # Network output
W # per-element weighting based on the distance map from UNet
loss = criterion(pd, gt)
loss = W*loss # Ensure that weights are scaled appropriately
loss = torch.sum(loss.flatten(start_dim=1), axis=0) # Sums the loss per image
loss = torch.mean(loss) # Average across a batch

现在,我有点困惑,哪个是正确的,或者还有其他方法,或者两者都是正确的?

python pytorch image-segmentation semantic-segmentation
1个回答
0
投票

如果您查看文档,将会看到second方法实际上是first版本的推出版本。默认情况下,torch.nn.functional.cross_entropy将批次中每个损失要素的损失平均。这正是您在第二种方法中手动执行的操作。因此,这只是样式问题。我会在之间写一些东西:)

criterion = torch.nn.CrossEntropyLoss(weight=W)
...
loss = criterion(pd, gt)

我相信forward这样的功能将变得更加透明和易读。而且,如果您想尝试其他损失函数,则只需更改单行即可。

© www.soinside.com 2019 - 2024. All rights reserved.