Argmax over vectors in numpy

问题描述 投票:0回答:1

我正在尝试编写尽可能干净高效的表达式,以解决以下问题。我有一组

[[v11,...,v1n],...,[vm1,...,vmn]]
向量(存储为
(m,n,a)
数组)和第二组
[b1,...,bk]
(同样,作为
(k,a)
数组)。我想要的是所有
argmax_j dot(vij, bx)
i,x
。我现在有以下内容,它会创建不必要的大数组,然后我必须采用对角线。

gamma = [[v11,...,v1n],...,[vm1,...,vmn]]  # coordinates (i,j,s)
b = [b1,...,bk]  # coordinates (k,s)

# compute argmaxes
argmaxes = (
    np.tensordot(gamma, b, axes=[[2], [1]])  # coordinates (i,j,k)
    .argmax(axis=1)  # coordinates (i,k)
)

result = np.diagonal(
    np.take(gamma, argmaxes, axis=1),  # coordinates (i, i, k, s)
    axis1=0, axis2=1
)  # coordinates (i, k, s)

有没有办法用

np.take
避免这一步,这给了我大量冗余信息?

python numpy linear-algebra
1个回答
1
投票

您可以使用 numpy 的高级索引来完成这个。

旁注,当

np.tensordot
也可以完成这项工作时,
np.inner
可能有点矫枉过正:

import numpy as np

M, N, A = 3, 5, 4
K = 6

gamma = np.random.rand(M, N, A)
b = np.random.rand(K, A)

argmaxes = np.inner(gamma, b).argmax(axis=1) # (M, K)

result = gamma[np.arange(M)[:, None], argmaxes] # (M, K, A)

或者,通过

np.take_along_axis

result = np.take_along_axis(gamma, argmaxes[:, :, None], axis=1)

另请注意,在您的示例实现中,

np.diagonal
在末尾插入新维度。以上与您的代码仅在转置上有所不同。

© www.soinside.com 2019 - 2024. All rights reserved.