Python中3维Numpy ndarray的高级索引

问题描述 投票:0回答:1

我有一个形状为(68、64、64)的ndarray,称为“预测”。这些尺寸对应于图像编号,高度,宽度。对于每个图像,我都有一个长度为2的元组,其中包含与每个64x64图像中的特定位置相对应的坐标,例如(12,45)。我可以将这些坐标堆叠到另一个形状为(68,2)的Numpy ndarray中,称为“位置”。

如何在不使用循环的情况下构造切片对象或构造必要的高级索引索引以访问这些位置?寻找有关语法的帮助。目标是使用无循环的纯Numpy矩阵。

工作循环结构

Import numpy as np
# example code with just ones...The real arrays have 'real' data.
prediction = np.ones((68,64,64), dtype='float32')
locations = np.ones((68,2), dtype='uint32')

selected_location_values = np.empty(prediction.shape[0], dtype='float32')
for index, (image, coordinates) in enumerate(zip(prediction, locations)):
    selected_locations_values[index] = image[coordinates]

所需方法

selected_location_values = np.empty(prediction.shape[0], dtype='float32')

correct_indexing = some_function_here(locations). # ?????
selected_locations_values = predictions[correct_indexing]
python numpy numpy-broadcasting numpy-indexing
1个回答
0
投票

一个简单的索引应该起作用:

img = np.arange(locations.shape[0])
r = locations[:, 0]
c = locations[:, 1]
selected_locations_values = predictions[img, r, c]

花式索引通过选择与广播索引的形状相对应的索引数组元素来工作。在这种情况下,索引非常简单。您只需要该范围即可告诉您每个位置对应的图像。

© www.soinside.com 2019 - 2024. All rights reserved.