我正在尝试运行序列到序列分类的代码,但遇到了一个名为
0D or 1D target tensor expected, multi-target not supported on loss = criterion(outputs, labels)
的错误,其中我的标签形状为 torch.Size([16, 323]) torch.int64
。我的标准是nn.CrossEntropyLoss()
。
我已经检查了我的列车和标签样本;它们都是 323 大小,与我的火车集令牌的最大长度相匹配。附加值是通过填充添加的,并且对于填充所在的位置,标签的值为 -100。
只是一个总结:我正在做序列到序列的分类。我有两个列表,一个包含句子列表,每个实例都被标记化,另一个包含与句子中每个单词相对应的标签列表。 例如:
B-O:既不是缩写也不是长形式的标记。
B-AC:标记为缩写/首字母缩略词的标记。
B-LF:长格式开头的标记。
I-LF:长格式中的标记。
到目前为止,我已经完成了一些数据预处理,其中包括旅鼠和小写字母。之后,我填充每个句子和标签以匹配训练的最大长度。然后,我使用 word2vec 将它们格式化为数字格式。然后,我将它们转换为张量,为我的双向 lstm 模型做好准备,但我得到了错误。
for epoch in range(num_epochs):
total_loss = 0.0
for i, (inputs, labels) in enumerate(train_loader):
#inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = bilstm_model(inputs)
# Ensure labels are in the correct format (single-dimensional class indices)
labels = labels.squeeze(dim=1) # Remove any unnecessary dimensions
print(labels.shape)
print(labels.dtype)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
total_loss += loss.item()
# Calculate average loss for the epoch
avg_loss = total_loss / len(train_loader)
print(f"Epoch [{epoch+1}/{num_epochs}], Avg Loss: {avg_loss:.4f}")
# Validation after each epoch
bilstm_model.eval() # Set model to evaluation mode
with torch.no_grad():
val_outputs = bilstm_model(val_embeddings_tensor)
val_loss = criterion(val_outputs, val_labels_tensor)
# Calculate accuracy for validation
_, predicted_labels = torch.max(val_outputs, 1)
correct_predictions = (predicted_labels == val_labels_tensor).sum().item()
val_accuracy = correct_predictions / len(val_labels_tensor)
print(f"Validation Accuracy: {val_accuracy:.4f}")
bilstm_model.train() # Set model back to training mode
上面的代码是模型训练。
我尝试将所有内容都放在 GPU 上,认为这就是问题所在,并尝试在互联网上查找解决方案,但一无所获。
您误解了交叉熵损失的计算方式。 之前的问题应该可以澄清您的理解。
如果我理解正确,您的模型应该输出形状为 [批量大小、标记数量、可能的类别] 的 logits,并且您的标签应为形状 [批量大小、标记数量],其中每个值对应于热中的目标索引- 标签的编码版本。
如果您的输入只应产生单个类输出,则要容易得多:再次使用标签的热编码版本的目标索引输出[批量大小,可能的类]的形状和[批量大小]的标签形状。
请注意,模型的输出不应经过激活函数,而应采用原始 logits。此外,请注意 CrossEntropyLoss 期望标签具有 Long dtype (torch.int64)。