如何找到张量对象中每行的最大索引?

问题描述 投票:0回答:1

所以我正在创建一个pytorch模型,对于前向传递,我正在应用我的前向传递方法来获得包含每个类的预测分数的分数张量。该张量的形状为[100,10]。现在,我希望通过将其与包含实际分数的y进行比较来获得准确性。该张量具有形状[100]。要比较两个我将使用torch.mean(scores == y),我会计算有多少是相同的。

问题是我需要转换得分张量,以便每行只包含每行中最高值的索引。例如,如果张量看起来像这样,

tensor(
    [[0.3232, -0.2321, 0.2332, -0.1231, 0.2435, 0.6728],

    [0.2323, -0.1231, -0.5321, -0.1452, 0.5435, 0.1722],

    [0.9823, -0.1321, -0.6433, 0.1231, 0.023, 0.0711]]
)

然后我希望它被转换,使它看起来像这样。

tensor([5, 4, 0])

我怎么能这样做?

python python-3.x tensorflow pytorch
1个回答
0
投票

使用argmax与所需的dim(a.k.a。轴)

a = tensor(
    [[0.3232, -0.2321, 0.2332, -0.1231, 0.2435, 0.6728],
    [0.2323, -0.1231, -0.5321, -0.1452, 0.5435, 0.1722],
    [0.9823, -0.1321, -0.6433, 0.1231, 0.023, 0.0711]]
)

a.argmax(1)
# tensor([ 5,  4,  0])
© www.soinside.com 2019 - 2024. All rights reserved.