Pytorch TypeError - eq()收到一个无效的参数组合。

问题描述 投票:0回答:1

我正在用BERT处理一个文本分类问题。当在本地机器上进行训练时,一切都很正常,但是当切换到服务器上时,我得到了以下错误。

<ipython-input-28-508d35ac5f5f> in flat_accuracy(preds, labels)
      5     pred_flat = np.argmax(preds, axis=1).flatten()
      6     labels_flat = labels.flatten()
----> 7     return np.sum(pred_flat == labels_flat) / len(labels_flat)
      8 
      9 # Function to calculate the f1_score of our predictions vs labels

TypeError: eq() received an invalid combination of arguments - got (numpy.ndarray), but expected one of:
 * (Tensor other)
      didn't match because some of the arguments have invalid types: (numpy.ndarray)
 * (Number other)
      didn't match because some of the arguments have invalid types: (numpy.ndarray)

代码:

def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)

代码: 本地机器上的Torch版本: 1.4.0

Torch版本在服务器上 1.3.1

任何帮助将是非常感激的!

numpy machine-learning pytorch tensor
1个回答
0
投票

可能是 eq 在你的服务器上,Torch版本的实现不再允许你在一个 torch.Tensor 和a np.ndarray. 你应该胁迫任一 pred_flat 拟做 torch.Tensor或胁迫 labels_flat 是一个numpy数组。因为你使用的是 np.sum 中的返回语句,而你只是返回一个标量值,我就会把所有的东西都移到numpy中,所以

labels_flat = labels.numpy()

但如果你在GPU上,你可能需要调用 labels.cpu().numpy()如果你要追踪标签上的渐变,你可能需要 labels.detach().cpu().numpy().

© www.soinside.com 2019 - 2024. All rights reserved.