Numba在numpy中使用布尔索引添加值的同义词。

问题描述 投票:0回答:1

我试图创建更高效的代码,但在实现下面的Numba版本时被卡住了。

import numpy as np

a = np.array([[0, 0, 0, 0],
              [0, 0, 0, 0]])

bool_idx = np.array([True, False, False, True])

a[0, bool_idx] += 3
a

array([[3, 0, 0, 3],
       [0, 0, 0, 0]])

不幸的是,当我把这段代码移植到numba函数中时,我得到了一个错误。

@njit
def add_to_arr(a, idx, arr_bool, add):
    arr[idx, arr_bool] += 3
    return arr

add_to_arr(a=a, idx=0, arr_bool=bool_idx, add=3)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function getitem>) with argument(s) of type(s): (array(int32, 2d, C), (int64, array(bool, 1d, C)))
python numpy numba
1个回答
0
投票

在这种情况下,Numba似乎只允许在数组的第一个维度上进行高级索引。我们可以重写函数(也可以纠正一个轻微的错别字)来适应这种情况,只需使用转置和反转索引。

@njit 
def add_to_arr(a, idx, arr_bool, add): 
    a.T[arr_bool, idx] += 3 
    return a 

add_to_arr(a, 0, bool_idx, 3)     

这对我来说是可行的,结果是:

array([[3, 0, 0, 3],
       [0, 0, 0, 0]])

得到: 文件 说高级索引只允许在一个维度中使用,但没有指定这个维度需要是第一个维度,所以这可能是一个bug。

© www.soinside.com 2019 - 2024. All rights reserved.