为什么我收到 IndexError

问题描述 投票:0回答:1

为什么以下代码:

import torch
import torch.nn as nn

input_tensor = torch.tensor([[2.0]])  # shape (1, 1): one sample, one class score, so only class index 0 exists
target_tensor = torch.tensor([0], dtype=torch.long)  # target 0 is within the valid range [0, C-1] = [0, 0]
loss_function = nn.CrossEntropyLoss()  # expects (N, C) scores and integer class indices in [0, C-1]
loss = loss_function(input_tensor, target_tensor)  # succeeds (with a single class the loss is 0)
print(loss)


input_tensor = torch.tensor([[2.0]])  # still only C == 1 class, so target 1 below is out of range -> IndexError
target_tensor = torch.tensor([1], dtype=torch.long)
loss_function = nn.CrossEntropyLoss()
loss = loss_function(input_tensor, target_tensor)
print(loss)

显示以下错误:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[214], line 14
     12 target_tensor = torch.tensor([1], dtype=torch.long)
     13 loss_function = nn.CrossEntropyLoss()
---> 14 loss = loss_function(input_tensor, target_tensor)
     15 print(loss)

File ~/jupyter-env-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1106 # If we don't have any hooks, we want to skip the rest of the logic in
   1107 # this function, and just call forward.
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used
   1112 full_backward_hooks, non_full_backward_hooks = [], []

File ~/jupyter-env-3.10/lib/python3.10/site-packages/torch/nn/modules/loss.py:1163, in CrossEntropyLoss.forward(self, input, target)
   1162 def forward(self, input: Tensor, target: Tensor) -> Tensor:
-> 1163     return F.cross_entropy(input, target, weight=self.weight,
   1164                            ignore_index=self.ignore_index, reduction=self.reduction,
   1165                            label_smoothing=self.label_smoothing)

File ~/jupyter-env-3.10/lib/python3.10/site-packages/torch/nn/functional.py:2996, in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   2994 if size_average is not None or reduce is not None:
   2995     reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2996 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)

IndexError: Target 1 is out of bounds.
python pytorch loss-function
1个回答
0
投票

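这是因为 nn.CrossEntropyLoss 的输入 input_tensor 的形状是 (N, C):每一行对应一个样本,每一列是对应类别的未归一化分数(logits,而不是概率,损失函数内部会自行做 log-softmax),列数 C 就是类别数。目标类别索引必须落在 [0, C-1] 范围内。你的 input_tensor 是 [[2.0]],只有 1 列(C = 1),所以唯一合法的目标是 0;当目标类别是 1 时就越界了。要使用目标 1,input_tensor 每一行的长度至少应为 2(类别 0 的分数和类别 1 的分数),例如: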

import torch
import torch.nn as nn

# nn.CrossEntropyLoss takes raw, unnormalized scores (logits) of shape (N, C)
# and integer class indices of shape (N,) with values in [0, C-1]. It applies
# log-softmax internally, so the inputs are NOT probabilities. The module is
# stateless here, so one instance can be reused for both calls.
loss_function = nn.CrossEntropyLoss()

# One sample, one class (C == 1): the only valid target is 0. With a single
# class, log-softmax of that one score is always 0, so this loss is 0
# regardless of the logit value.
logits = torch.tensor([[2.0]])
target_tensor = torch.tensor([0], dtype=torch.long)
loss = loss_function(logits, target_tensor)
print(loss)

# One sample, two classes (C == 2): target index 1 is now in range, so no
# IndexError is raised.
logits = torch.tensor([[2.0, 3.3]])
target_tensor = torch.tensor([1], dtype=torch.long)
loss = loss_function(logits, target_tensor)
print(loss)
© www.soinside.com 2019 - 2024. All rights reserved.