Computing the MSE loss of two tensors with the same shape

Question

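I am trying to perform a classification task, but for certain reasons I need to remove the softmax and replace the loss module, switching from cross-entropy to MSE. To create a one-hot tensor for the labels (targets), I do the following: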

    import torch.nn as nn

    labels_onehot = nn.functional.one_hot(labels, num_classes=10).float()
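For reference, a small sketch of what `one_hot` does to the shape and dtype. The batch of four labels is made up for illustration: a `(N,)` tensor of integer class indices becomes an `(N, num_classes)` tensor, and `.float()` is needed because `one_hot` returns an integer tensor while `MSELoss` expects floating-point targets.

```python
import torch
import torch.nn.functional as F

# Hypothetical batch of 4 integer class labels in [0, 10)
labels = torch.tensor([3, 0, 9, 3])

# one_hot takes int64 class indices and returns an int64 tensor,
# so cast to float before using it as an MSE target
labels_onehot = F.one_hot(labels, num_classes=10).float()

print(labels.shape)         # shape (4,)
print(labels_onehot.shape)  # shape (4, 10)
print(labels_onehot.dtype)  # float32
```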

But when I try to compute the loss, an exception is thrown:

Cell In[13], line 121
    116 print("Labels one-hot shape:", labels_onehot.shape)
    118 loss = criterion(outputs, labels_onehot)
--> 121 loss = criterion(outputs, labels)
    122 loss.backward()
    123 optimizer.step()

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/loss.py:535, in MSELoss.forward(self, input, target)
    534 def forward(self, input: Tensor, target: Tensor) -> Tensor:
--> 535     return F.mse_loss(input, target, reduction=self.reduction)

File /opt/conda/lib/python3.10/site-packages/torch/nn/functional.py:3328, in mse_loss(input, target, size_average, reduce, reduction)
   3325 if size_average is not None or reduce is not None:
   3326     reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3328 expanded_input, expanded_target = torch.broadcast_tensors(input, target)
   3329 return torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))

File /opt/conda/lib/python3.10/site-packages/torch/functional.py:73, in broadcast_tensors(*tensors)
     71 if has_torch_function(tensors):
     72     return handle_torch_function(broadcast_tensors, tensors, *tensors)
---> 73 return _VF.broadcast_tensors(tensors)

RuntimeError: The size of tensor a (10) must match the size of tensor b (64) at non-singleton dimension 1

I printed the shapes, and both tensors have the same shape, so I can't see why the exception is thrown.
Tags: python, pytorch, tensor, loss, mse
Answer

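The exception comes from line 121 in the traceback. The shape you printed was that of `labels_onehot`, which matches `outputs`, but that line passes the integer `labels` (shape `(batch,)`) to the criterion: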

loss = criterion(outputs, labels)
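A minimal reproduction, assuming (from the error message) that `outputs` has shape `(64, 10)` and `labels` has shape `(64,)`; random tensors stand in for the real data. `mse_loss` broadcasts its two arguments, and broadcasting aligns trailing dimensions, so `(64,)` is treated as `(1, 64)` and its 64 is compared against the 10 classes in dimension 1 of `outputs`. That mismatch raises the same kind of RuntimeError as in the question.

```python
import torch
import torch.nn as nn

criterion = nn.MSELoss()

# Assumed shapes, inferred from the error message: batch of 64, 10 classes
outputs = torch.randn(64, 10)         # raw model outputs (no softmax)
labels = torch.randint(0, 10, (64,))  # integer class indices, shape (64,)

# (64, 10) vs (64,): trailing dims are compared, i.e. 10 vs 64 -> error
error_message = None
try:
    criterion(outputs, labels)
except RuntimeError as e:
    error_message = str(e)
    print("RuntimeError:", error_message)
```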

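Since you are one-hot encoding the labels, pass `labels_onehot` instead of `labels`. Line 118 already does this, but line 121 then overwrites `loss` with the failing call on the integer `labels`, so remove that line or change it to: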

loss = criterion(outputs, labels_onehot)
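Putting it together, a sketch of a single training step where the loss is computed once, against the one-hot targets. The model, input size (flattened 28x28), learning rate, and random batch are all hypothetical placeholders, not the asker's actual setup; the point is only that the raw outputs (no softmax) and the float one-hot targets both have shape `(batch, 10)`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical model: a linear classifier on flattened 28x28 inputs, 10 classes,
# with no softmax at the end (raw outputs are regressed onto one-hot targets)
model = nn.Linear(28 * 28, 10)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)

# Hypothetical batch of 64 samples
inputs = torch.randn(64, 28 * 28)
labels = torch.randint(0, 10, (64,))

def train_step(inputs, labels):
    optimizer.zero_grad()
    outputs = model(inputs)                                    # (64, 10)
    labels_onehot = F.one_hot(labels, num_classes=10).float()  # (64, 10)
    # Compute the loss once, against the one-hot targets only
    loss = criterion(outputs, labels_onehot)
    loss.backward()
    optimizer.step()
    return loss.item()

# Repeating the step on the same batch should drive the loss down
losses = [train_step(inputs, labels) for _ in range(20)]
print(losses[0], losses[-1])
```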