对我来说很奇怪的是,当输入是数组列表时,vstack 不能与 Numba 一起使用,它仅在输入是数组元组时才起作用。示例代码:
@nb.jit(nopython=True)
def stack(items):
return np.vstack(items)
stack((np.array([1,2,3]), np.array([4,5,6])))
退货
array([[1, 2, 3],
[4, 5, 6]])
但是
stack([np.array([1,2,3]), np.array([4,5,6])])
抛出错误
TypingError: No implementation of function Function(<function vstack at 0x0000027271963488>) found for signature:
>>>vstack(reflected list(array(int32, 1d, C)))
由于不支持元组,我很难找到解决方法 - 我错过了什么吗?
这是 @hpaulj 提到的解决方法:
stack(tuple([np.array([1,2,3]), np.array([4,5,6])]))
[[1 2 3]
[4 5 6]]
在 numba 中,
vstack
、hstack
以及 concatenate
仅支持 tuple
作为输入,而不支持 list
。
他们说这是因为 numba 在编译过程中无法推断堆叠数组的维度[参考]。但我怀疑它实际上可以,因为你可以手动执行此操作。
您可以通过这种间接方式堆叠
list
:
from numba import njit, prange
import numpy as np
@njit()
def test_list_stack(i, array_to_be_stacked):
shape = (i,) + array_to_be_stacked.shape
list_of_array = [array_to_be_stacked] * i
stacked_array = np.empty(shape)
for j in prange(i):
stacked_array[j] = list_of_array[j]
return stacked_array
if __name__ == "__main__":
test_list_stack(10, np.ones((2, 3)))