我正在计算我的准确性,就像
(outputs.round() == targets).all(dim=2).all(dim=1).sum().item() / outputs.shape[0]
其中
outputs
和 targets
的形状为 NxAxB
。
N
是批量大小。
剩下的部分是预测/真值,我想看看它们是否相同。
目前我正在使用
.all(dim=2).all(dim=1)
。
现在的问题是,如果我有不同的模型,形状就会不同。它们将是 NxA
,所以我当前的方法不起作用,因为 dim=2
不存在。
(outputs.round() == targets).all(dim=1).sum().item() / outputs.shape[0]
,可以工作,但又只适用于第二个模型。
理想情况下,我想将
.all
应用于除第一个维度(批量维度)之外的所有内容。
我该怎么做?
要推广到任意数量的维度,您可以使用
dim=1
将布尔张量从
torch.flatten
向外展平,然后应用 all
和 mean
:
>>> (outputs.round() == targets).flatten(1).all(1).float().mean()
注意:
torch.flatten(dim=1)
会将张量从dim=1
展平到dim=-1
。