将 torch.all 应用于除第一个维度之外的每个维度

问题描述 投票:0回答:0

我正在计算我的准确性,就像

(outputs.round() == targets).all(dim=2).all(dim=1).sum().item() / outputs.shape[0]

其中

outputs
targets
的形状为
NxAxB
N
是批量大小。 剩下的部分是预测/真值,我想看看它们是否相同。

目前我正在使用

.all(dim=2).all(dim=1)
。 现在的问题是,如果我有不同的模型,形状就会不同。它们将是
NxA
,所以我当前的方法不起作用,因为
dim=2
不存在。

(outputs.round() == targets).all(dim=1).sum().item() / outputs.shape[0]

,可以工作,但又只适用于第二个模型。

理想情况下,我想将

.all
应用于除第一个维度(批量维度)之外的所有内容。 我该怎么做?

python pytorch tensor
© www.soinside.com 2019 - 2024. All rights reserved.