我试图创建更高效的代码,但在实现下面的Numba版本时被卡住了。
import numpy as np
a = np.array([[0, 0, 0, 0],
[0, 0, 0, 0]])
bool_idx = np.array([True, False, False, True])
a[0, bool_idx] += 3
a
array([[3, 0, 0, 3],
[0, 0, 0, 0]])
不幸的是,当我把这段代码移植到numba函数中时,我得到了一个错误。
@njit
def add_to_arr(a, idx, arr_bool, add):
arr[idx, arr_bool] += 3
return arr
add_to_arr(a=a, idx=0, arr_bool=bool_idx, add=3)
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function getitem>) with argument(s) of type(s): (array(int32, 2d, C), (int64, array(bool, 1d, C)))
在这种情况下,Numba似乎只允许在数组的第一个维度上进行高级索引。我们可以重写函数(也可以纠正一个轻微的错别字)来适应这种情况,只需使用转置和反转索引。
@njit
def add_to_arr(a, idx, arr_bool, add):
a.T[arr_bool, idx] += 3
return a
add_to_arr(a, 0, bool_idx, 3)
这对我来说是可行的,结果是:
array([[3, 0, 0, 3],
[0, 0, 0, 0]])
得到: 文件 说高级索引只允许在一个维度中使用,但没有指定这个维度需要是第一个维度,所以这可能是一个bug。