如何在numpy中找到每行中最大的索引,行的串联?

问题描述 投票:0回答:1

我不知道这是否简单,或者是否先询问过。 (我搜索但没有找到正确的方法。我找到了numpy.argmaxnumpy.amax,但我无法正确使用它们。)

我有一个numpy数组(它是一个CxKxN矩阵)如下(C=K=N=3):

array([[[1, 2, 3],
        [2, 1, 4],
        [4, 3, 3]],

       [[2, 1, 1],
        [1, 3, 1],
        [3, 4, 2]],

       [[5, 2, 1],
        [3, 3, 3],
        [4, 1, 2]]])

我想找到每行最大元素的索引。一条线是每个矩阵的三个(C)行的串联。换句话说,i-th线是第一个矩阵中的i-row,第二个矩阵中的i-th行的串联,......,直到i-th矩阵中的C-th行。

例如,第一行是

[1, 2, 3, 2, 1, 1, 5, 2, 1]

所以我想回来

[2, 0, 0] # the index of the maximum in the first line

[0, 1, 2] # the index of the maximum in the second line

[0, 2, 0] # the index of the maximum in the third line

要么

[1, 2, 1] # the index of the maximum in the third line

要么

[2, 2, 0] # the index of the maximum in the third line

现在,我正在尝试这个

np.argmax(a[:,0,:], axis=None) # for the first line

它返回6

np.argmax(a[:,1,:], axis=None) 

它返回2

np.argmax(a[:,2,:], axis=None) 

它返回0

但我能够将这些数字转换为6 = (2,0,0)等指数。

python numpy
1个回答
1
投票

通过转置和重塑我得到你的'行'

In [367]: arr.transpose(1,0,2).reshape(3,9)
Out[367]: 
array([[1, 2, 3, 2, 1, 1, 5, 2, 1],
       [2, 1, 4, 1, 3, 1, 3, 3, 3],
       [4, 3, 3, 3, 4, 2, 4, 1, 2]])
In [368]: np.argmax(_, axis=1)
Out[368]: array([6, 2, 0])

这些最大值与您的相同。相同的索引,但在(3,3)数组中:

In [372]: np.unravel_index([6,2,0],(3,3))
Out[372]: (array([2, 0, 0]), array([0, 2, 0]))

加入中等维度范围:

In [373]: tup = (_[0],np.arange(3),_[1])
In [374]: np.transpose(tup)
Out[374]: 
array([[2, 0, 0],
       [0, 1, 2],
       [0, 2, 0]])
© www.soinside.com 2019 - 2024. All rights reserved.