如何多维化这段从值列表中查找索引的代码?

问题描述 投票:0回答:2

我有一些代码可以查找“输入”列表中与“值”列表中的任何元素匹配的元素的索引。然后输出索引,按照与“值”列表相同的顺序排列。

input = [1, 2, 8, 7, 3, 4, 6, 5, 9]
values = [4, 8, 3]

match_index_lst, match_index_values = np.where(np.array(input) == np.array(values)[:,None])

output_indice_lst = match_index_values[np.argsort(match_index_lst)]
# [5, 2, 4]

我的问题是,是否可以有效地(使用矢量化操作)扩展此代码以在特定的多维列表中使用它?目前输入列表的尺寸为

c
,但我将拥有尺寸为
[a, b, c]
的输入。所以不要像这样:

input = [1, 2, 8, 7, 3, 4, 6, 5, 9]
values = [4, 8, 3]
# output: [5, 2, 4]

我会有类似的东西

input = [[[[ 0.31, 1.56, 1.58, 0.16, 0.22, 0.54, 0.98, 0.35 ]],
          [[ 0.77, 2.62, 0.44, 0.08, 0.76, 0.87, 0.88, 0.51 ]]],

         [[[ 1.14, 0.48, 1.09, 0.93, 0.47, 0.13, 0.75, 0.19 ]],
          [[ 1.15, 0.17, 2.33, 0.46, 0.30, 2.60, 0.79, 1.07 ]]]]

values = [[[[ 0.54, 1.58 ]],
           [[ 0.77, 0.88 ]]],

          [[[ 0.48, 1.09 ]],
           [[ 2.60, 2.33 ]]]]


# output: [[[[ 5, 2 ]],
#           [[ 0, 6 ]]],
#
#          [[[ 1, 2 ]],
#           [[ 5, 2 ]]]]

我的具体示例是尺寸

(2, 2, 8)
,但它可以是任何
(a,b,c)
尺寸。

我尝试过将其展平,然后对其进行操作,但在展平它之后我似乎无法正确获得顺序,然后正确格式化输出也是一场噩梦。我可以看到使用 for 循环实现起来非常容易,但我想将其作为最后的手段,因为速度至关重要。

python arrays list numpy multidimensional-array
2个回答
0
投票

请注意,将

np.where
与广播
==
一起使用效率非常低。考虑使用
searchsorted
:

def match(arr, vals):
    k = arr.shape[-1]
    def fn(x):
        x, y = x[:k], x[k:]
        idx = x.argsort()
        return idx[np.searchsorted(x[idx], y)]
    return np.apply_along_axis(fn, len(arr.shape) - 1, np.c_[arr,vals])

match(np.array(input_arr), np.array(values))
array([[[5, 2],
        [0, 1]],

       [[1, 0],
        [5, 5]]], dtype=int64)

请注意,您可以使用

lambda
功能,如下所示:

def match(arr, vals):
    k = arr.shape[-1]
    fn = lambda x : (i1 := x[:k].argsort())[np.searchsorted(x[:k][i1], x[k:])]
    return np.apply_along_axis(fn, len(arr.shape) - 1, np.c_[arr,vals])

match(np.array(input_arr), np.array(values))
array([[[5, 2],
        [0, 1]],

       [[1, 0],
        [5, 5]]], dtype=int64)

0
投票

我认为您使用扁平索引的想法是正确的。看起来是这样的:

import numpy as np

input = np.array([[[[ 0.31, 1.56, 1.58, 0.16, 0.22, 0.54, 0.98, 0.35],
                    [ 0.77, 2.62, 0.44, 0.08, 0.76, 0.87, 0.87, 0.51]],
                   [[ 1.14, 0.48, 1.08, 0.93, 0.47, 0.13, 0.75, 0.19 ],
                    [ 1.15, 0.17, 2.32, 0.46, 0.30, 2.60, 0.79, 1.07 ]]]])

values = np.array([[[[ 0.54, 1.58 ]],
                    [[ 0.77, 0.88 ]]],
                   [[[ 0.48, 1.09 ]],
                    [[ 2.60, 2.33 ]]]])

sort_idx = np.argsort(input.flat)
output_flat = sort_idx[np.searchsorted(input.flat, values.flat, sorter=sort_idx)]
output = np.unravel_index(output_flat.reshape(values.shape), input.shape)[-1]
print(output)

哪个打印:

[[[[5 2]]

  [[0 3]]]


 [[[1 0]]

  [[5 5]]]]

(请检查您的参考输出,它对我来说似乎不正确,特别是有一个问题,如果该值不在索引中应该发生什么)

关键的缺失部分是

np.unravel_index()
并重塑为
values.shape

我希望这有帮助!

© www.soinside.com 2019 - 2024. All rights reserved.