numba vstack 不适用于数组列表

问题描述 投票:0回答:2

对我来说很奇怪的是,当输入是数组列表时,vstack 不能与 Numba 一起使用,它仅在输入是数组元组时才起作用。示例代码:

@nb.jit(nopython=True)
def stack(items):
    return np.vstack(items)

stack((np.array([1,2,3]), np.array([4,5,6])))

退货

array([[1, 2, 3],
       [4, 5, 6]])

但是

stack([np.array([1,2,3]), np.array([4,5,6])])

抛出错误

TypingError: No implementation of function Function(<function vstack at 0x0000027271963488>) found for signature:
>>>vstack(reflected list(array(int32, 1d, C)))

由于不支持元组,我很难找到解决方法 - 我错过了什么吗?

python-3.x numpy numba
2个回答
1
投票

这是 @hpaulj 提到的解决方法:

stack(tuple([np.array([1,2,3]), np.array([4,5,6])]))

[[1 2 3]
 [4 5 6]]

0
投票

在 numba 中,

vstack
hstack
以及
concatenate
仅支持
tuple
作为输入,而不支持
list

他们说这是因为 numba 在编译过程中无法推断堆叠数组的维度[参考]。但我怀疑它实际上可以,因为你可以手动执行此操作。

您可以通过这种间接方式堆叠

list

from numba import njit, prange
import numpy as np


@njit()
def test_list_stack(i, array_to_be_stacked):
    shape = (i,) + array_to_be_stacked.shape
    list_of_array = [array_to_be_stacked] * i
    stacked_array = np.empty(shape)
    for j in prange(i):
        stacked_array[j] = list_of_array[j]
    return stacked_array


if __name__ == "__main__":
    test_list_stack(10, np.ones((2, 3)))
© www.soinside.com 2019 - 2024. All rights reserved.