我正在计算我的准确性,就像
(outputs.round() == targets).all(dim=2).all(dim=1).sum().item() / outputs.shape[0]
其中
outputs
和 targets
的形状为 NxAxB
。
N
是批量大小。
剩下的部分是预测/真值,我想看看它们是否相同。
目前我正在使用
.all(dim=2).all(dim=1)
。
现在的问题是,如果我有不同的模型,形状就会不同。它们将是 NxA
,所以我当前的方法不起作用,因为 dim=2
不存在。
(outputs.round() == targets).all(dim=1).sum().item() / outputs.shape[0]
,可以工作,但又只适用于第二个模型。
理想情况下,我想将
.all
应用于除第一个维度(批量维度)之外的所有内容。
我该怎么做?