各位程序员朋友们好!
我想使用这段代码,但是 np.hstack
函数似乎与 numba.njit
装饰者。
import numpy as np
import numba
@numba.njit
def main():
J_1 = np.array([[-64., 25.6, 25.6, 12.8], [25.6, -25.6, 0., 0.], [25.6, 0., -25.6, 0.], [12.8, 0., 0., -652.8]])
J_2 = np.array([[-85.33333333, 34.13333333, 34.13333333, 17.06666667], [34.13333333, -34.13333333, 0., 0.], [34.13333333, 0., -34.13333333, 0.], [17.06666667, 0., 0., -870.4]])
J_3 = np.array([[85.33333333, -34.13333333, -34.13333333, -17.06666667], [-34.13333333, 34.13333333, -0., -0.], [-34.13333333, -0., 34.13333333, -0.], [-17.06666667, -0., -0., 870.4]])
J_4 = np.array([[-64., 25.6, 25.6, 12.8], [25.6, -25.6, 0., 0.], [25.6, 0., -25.6, 0.], [12.8, 0., 0., -652.8]])
J_old = [[J_1, J_2], [J_3, J_4]]
J_stack = np.hstack(J_old[0])
for row in J_old[1:]:
col = np.hstack(row)
J = np.vstack((J_stack, col))
print(J)
if __name__ == '__main__':
main()
输出:
C:\Users\Artur\Anaconda\python.exe C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py
Traceback (most recent call last):
File "C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py", line 19, in <module>
main()
File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 401, in _compile_for_args
error_rewrite(e, 'typing')
File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 344, in error_rewrite
reraise(type(e), e, None)
File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\utils.py", line 80, in reraise
raise value.with_traceback(tb)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<function hstack at 0x000001981A60B558>) with argument(s) of type(s): (list(array(float64, 2d, C)))
* parameterized
In definition 0:
TypeError: np.hstack(): expecting a non-empty tuple of arrays, got list(array(float64, 2d, C))
raised from C:\Users\Artur\Anaconda\lib\site-packages\numba\core\typing\npydecl.py:779
In definition 1:
TypeError: np.hstack(): expecting a non-empty tuple of arrays, got list(array(float64, 2d, C))
raised from C:\Users\Artur\Anaconda\lib\site-packages\numba\core\typing\npydecl.py:779
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function(<function hstack at 0x000001981A60B558>)
[2] During: typing of call at C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py (11)
File "test2.py", line 11:
def main():
<source elided>
J_old = [[J_1, J_2], [J_3, J_4]]
J_stack = np.hstack(J_old[0])
^
Process finished with exit code 1
原本这个片段,
J_old = [[J_1, J_2], [J_3, J_4]]
J_stack = np.hstack(J_old[0])
for row in J_old[1:]:
col = np.hstack(row)
J = np.vstack((J_stack, col))
是替换了 J = np.bmat([[J_1, J_2], [J_3, J_4]])
的,但却不能与 numba.njit
装饰师也好。
np.hstack
是numba.njit装饰器中的一个 支持的numpy功能,错误信息中明确指出了其他内容。作为一个简单的解决方法,你可以在你的四块作业后使用下面的单行线,以构建 J
(在numba上测试) 0.48.0
):
J = np.vstack((np.hstack((J_1, J_2)),np.hstack((J_3, J_4))))
这给出的结果相当于 np.bmat
.
希望这能帮助你。
从错误信息来看
TypeError: np.hstack(): expecting a non-empty tuple of arrays, got list(array(float64, 2d, C))
我们看到的问题是:农巴版的。hstack
期待一个数组的元组,而你给了它一个数组的列表。(NumPy版本的 hstack
是比较宽容的,会让你使用一个列表)。)
这可以通过在你的 J_old
:
J_old = [(J_1, J_2), (J_3, J_4)]
在Numba中,更普遍的是尽可能使用元组,因为对列表的支持是存在的,但相当不完整,许多函数对列表并不满意,即使在技术上可以(虽然不是习惯性的)使用列表作为NumPy的参数。
(当然,Yacola的解决方案更多的是一种改进--我只是想指出,这是Numba工作所需的最小改动。)