我有一些代码可以查找“输入”列表中与“值”列表中的任何元素匹配的元素的索引。然后输出索引,按照与“值”列表相同的顺序排列。
input = [1, 2, 8, 7, 3, 4, 6, 5, 9]
values = [4, 8, 3]
match_index_lst, match_index_values = np.where(np.array(input) == np.array(values)[:,None])
output_indice_lst = match_index_values[np.argsort(match_index_lst)]
# [5, 2, 4]
我的问题是,是否可以有效地(使用矢量化操作)扩展此代码以在特定的多维列表中使用它?目前输入列表的尺寸为
c
,但我将拥有尺寸为 [a, b, c]
的输入。所以不要像这样:
input = [1, 2, 8, 7, 3, 4, 6, 5, 9]
values = [4, 8, 3]
# output: [5, 2, 4]
我会有类似的东西
input = [[[[ 0.31, 1.56, 1.58, 0.16, 0.22, 0.54, 0.98, 0.35 ]],
[[ 0.77, 2.62, 0.44, 0.08, 0.76, 0.87, 0.88, 0.51 ]]],
[[[ 1.14, 0.48, 1.09, 0.93, 0.47, 0.13, 0.75, 0.19 ]],
[[ 1.15, 0.17, 2.33, 0.46, 0.30, 2.60, 0.79, 1.07 ]]]]
values = [[[[ 0.54, 1.58 ]],
[[ 0.77, 0.88 ]]],
[[[ 0.48, 1.09 ]],
[[ 2.60, 2.33 ]]]]
# output: [[[[ 5, 2 ]],
# [[ 0, 6 ]]],
#
# [[[ 1, 2 ]],
# [[ 5, 2 ]]]]
我的具体示例是尺寸
(2, 2, 8)
,但它可以是任何 (a,b,c)
尺寸。
我尝试过将其展平,然后对其进行操作,但在展平它之后我似乎无法正确获得顺序,然后正确格式化输出也是一场噩梦。我可以看到使用 for 循环实现起来非常容易,但我想将其作为最后的手段,因为速度至关重要。
请注意,将
np.where
与广播 ==
一起使用效率非常低。考虑使用searchsorted
:
def match(arr, vals):
k = arr.shape[-1]
def fn(x):
x, y = x[:k], x[k:]
idx = x.argsort()
return idx[np.searchsorted(x[idx], y)]
return np.apply_along_axis(fn, len(arr.shape) - 1, np.c_[arr,vals])
match(np.array(input_arr), np.array(values))
array([[[5, 2],
[0, 1]],
[[1, 0],
[5, 5]]], dtype=int64)
请注意,您可以使用
lambda
功能,如下所示:
def match(arr, vals):
k = arr.shape[-1]
fn = lambda x : (i1 := x[:k].argsort())[np.searchsorted(x[:k][i1], x[k:])]
return np.apply_along_axis(fn, len(arr.shape) - 1, np.c_[arr,vals])
match(np.array(input_arr), np.array(values))
array([[[5, 2],
[0, 1]],
[[1, 0],
[5, 5]]], dtype=int64)
我认为您使用扁平索引的想法是正确的。看起来是这样的:
import numpy as np
input = np.array([[[[ 0.31, 1.56, 1.58, 0.16, 0.22, 0.54, 0.98, 0.35],
[ 0.77, 2.62, 0.44, 0.08, 0.76, 0.87, 0.87, 0.51]],
[[ 1.14, 0.48, 1.08, 0.93, 0.47, 0.13, 0.75, 0.19 ],
[ 1.15, 0.17, 2.32, 0.46, 0.30, 2.60, 0.79, 1.07 ]]]])
values = np.array([[[[ 0.54, 1.58 ]],
[[ 0.77, 0.88 ]]],
[[[ 0.48, 1.09 ]],
[[ 2.60, 2.33 ]]]])
sort_idx = np.argsort(input.flat)
output_flat = sort_idx[np.searchsorted(input.flat, values.flat, sorter=sort_idx)]
output = np.unravel_index(output_flat.reshape(values.shape), input.shape)[-1]
print(output)
哪个打印:
[[[[5 2]]
[[0 3]]]
[[[1 0]]
[[5 5]]]]
(请检查您的参考输出,它对我来说似乎不正确,特别是有一个问题,如果该值不在索引中应该发生什么)
关键的缺失部分是
np.unravel_index()
并重塑为 values.shape
。
我希望这有帮助!