我有一些数据y_hat
看起来像:
[[0. 1. 0. ... 0. 0. 0.]
[0. 1. 0. ... 0. 0. 0.]
[0. 1. 0. ... 0. 0. 0.]
...
[0. 1. 0. ... 0. 0. 0.]
[0. 1. 0. ... 0. 0. 0.]
[0. 1. 0. ... 0. 0. 0.]]
我想获得每行的argmax
,以便我最终得到一个矢量:
[[3]
[8]
[8]
...
[5]
[1]
[7]]
如果我只是做np.argmax(y_hat)
,它会返回一个1
。
这是argmax
与numpy
广播之后的一种方式
a.argmax(axis = 1)[:,None]
要么
a[:,None].argmax(-1)
np.argmax
accepts an axis
keyword argument。用那个。
列为axis=0
,行为axis=1
。