我正在用BERT处理一个文本分类问题。当在本地机器上进行训练时,一切都很正常,但是当切换到服务器上时,我得到了以下错误。
<ipython-input-28-508d35ac5f5f> in flat_accuracy(preds, labels)
5 pred_flat = np.argmax(preds, axis=1).flatten()
6 labels_flat = labels.flatten()
----> 7 return np.sum(pred_flat == labels_flat) / len(labels_flat)
8
9 # Function to calculate the f1_score of our predictions vs labels
TypeError: eq() received an invalid combination of arguments - got (numpy.ndarray), but expected one of:
* (Tensor other)
didn't match because some of the arguments have invalid types: (numpy.ndarray)
* (Number other)
didn't match because some of the arguments have invalid types: (numpy.ndarray)
代码:
def flat_accuracy(preds, labels):
pred_flat = np.argmax(preds, axis=1).flatten()
labels_flat = labels.flatten()
return np.sum(pred_flat == labels_flat) / len(labels_flat)
代码: 本地机器上的Torch版本: 1.4.0
Torch版本在服务器上 1.3.1
任何帮助将是非常感激的!
可能是 eq
在你的服务器上,Torch版本的实现不再允许你在一个 torch.Tensor
和a np.ndarray
. 你应该胁迫任一 pred_flat
拟做 torch.Tensor
或胁迫 labels_flat
是一个numpy数组。因为你使用的是 np.sum
中的返回语句,而你只是返回一个标量值,我就会把所有的东西都移到numpy中,所以
labels_flat = labels.numpy()
但如果你在GPU上,你可能需要调用 labels.cpu().numpy()
如果你要追踪标签上的渐变,你可能需要 labels.detach().cpu().numpy()
.