numpy.hstack是numba.njit的替代品

问题描述 投票:0回答:1

各位程序员朋友们好!

我想使用这段代码,但是 np.hstack 函数似乎与 numba.njit 装饰者。

import numpy as np
import numba

@numba.njit
def main():
    J_1 = np.array([[-64., 25.6, 25.6, 12.8], [25.6, -25.6, 0., 0.], [25.6, 0., -25.6, 0.], [12.8, 0., 0., -652.8]])
    J_2 = np.array([[-85.33333333, 34.13333333, 34.13333333, 17.06666667], [34.13333333, -34.13333333, 0., 0.], [34.13333333, 0., -34.13333333, 0.], [17.06666667, 0., 0., -870.4]])
    J_3 = np.array([[85.33333333, -34.13333333, -34.13333333, -17.06666667], [-34.13333333, 34.13333333, -0., -0.], [-34.13333333, -0., 34.13333333, -0.], [-17.06666667, -0., -0., 870.4]])
    J_4 = np.array([[-64., 25.6, 25.6, 12.8], [25.6, -25.6, 0., 0.], [25.6, 0., -25.6, 0.], [12.8, 0., 0., -652.8]])
    J_old = [[J_1, J_2], [J_3, J_4]]
    J_stack = np.hstack(J_old[0])
    for row in J_old[1:]:
        col = np.hstack(row)
        J = np.vstack((J_stack, col))

    print(J)

if __name__ == '__main__':
    main()

输出:

C:\Users\Artur\Anaconda\python.exe C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py
Traceback (most recent call last):
  File "C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py", line 19, in <module>
    main()
  File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 401, in _compile_for_args
    error_rewrite(e, 'typing')
  File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 344, in error_rewrite
    reraise(type(e), e, None)
  File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\utils.py", line 80, in reraise
    raise value.with_traceback(tb)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<function hstack at 0x000001981A60B558>) with argument(s) of type(s): (list(array(float64, 2d, C)))
 * parameterized
In definition 0:
    TypeError: np.hstack(): expecting a non-empty tuple of arrays, got list(array(float64, 2d, C))
    raised from C:\Users\Artur\Anaconda\lib\site-packages\numba\core\typing\npydecl.py:779
In definition 1:
    TypeError: np.hstack(): expecting a non-empty tuple of arrays, got list(array(float64, 2d, C))
    raised from C:\Users\Artur\Anaconda\lib\site-packages\numba\core\typing\npydecl.py:779
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function(<function hstack at 0x000001981A60B558>)
[2] During: typing of call at C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py (11)


File "test2.py", line 11:
def main():
    <source elided>
    J_old = [[J_1, J_2], [J_3, J_4]]
    J_stack = np.hstack(J_old[0])
    ^


Process finished with exit code 1

原本这个片段,

J_old = [[J_1, J_2], [J_3, J_4]]
J_stack = np.hstack(J_old[0])
for row in J_old[1:]:
    col = np.hstack(row)
    J = np.vstack((J_stack, col))

是替换了 J = np.bmat([[J_1, J_2], [J_3, J_4]]) 的,但却不能与 numba.njit 装饰师也好。

python numpy jit numba
1个回答
1
投票

np.hstack 是numba.njit装饰器中的一个 支持的numpy功能,错误信息中明确指出了其他内容。作为一个简单的解决方法,你可以在你的四块作业后使用下面的单行线,以构建 J (在numba上测试) 0.48.0):

J = np.vstack((np.hstack((J_1, J_2)),np.hstack((J_3, J_4))))

这给出的结果相当于 np.bmat.

希望这能帮助你。


0
投票

从错误信息来看

TypeError: np.hstack(): expecting a non-empty tuple of arrays, got list(array(float64, 2d, C))

我们看到的问题是:农巴版的。hstack 期待一个数组的元组,而你给了它一个数组的列表。(NumPy版本的 hstack 是比较宽容的,会让你使用一个列表)。)

这可以通过在你的 J_old:

J_old = [(J_1, J_2), (J_3, J_4)]

在Numba中,更普遍的是尽可能使用元组,因为对列表的支持是存在的,但相当不完整,许多函数对列表并不满意,即使在技术上可以(虽然不是习惯性的)使用列表作为NumPy的参数。

(当然,Yacola的解决方案更多的是一种改进--我只是想指出,这是Numba工作所需的最小改动。)

© www.soinside.com 2019 - 2024. All rights reserved.