import numpy as np
from numba import njit
@njit
def matmul_transposed(a: np.ndarray, b: np.ndarray) -> np.ndarray:
# return a @ b.T # also tried this with a similar result, np.matmul seems to be unsupported by Numba
return a.dot(b.transpose())
matmul_transposed(np.array([[1.0, 1.0]]), np.array([[1.0, 1.0]]))
上面的例子引发了一个错误
- Resolution failure for literal arguments:
No implementation of function Function(<function array_dot at 0x...>) found for signature:
>>> array_dot(array(float64, 2d, C), array(float64, 2d, F))
There are 2 candidate implementations:
- Of which 2 did not match due to:
Overload in function 'array_dot': File: numba\np\arrayobj.py: Line 5929.
With argument(s): '(array(float64, 2d, C), array(float64, 2d, F))':
Rejected as the implementation raised a specific error:
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function dot at 0x...>) found for signature:
>>> dot(array(float64, 2d, C), array(float64, 2d, F))
There are 4 candidate implementations:
- Of which 2 did not match due to:
Overload in function 'dot_2': File: numba\np\linalg.py: Line 525.
With argument(s): '(array(float64, 2d, C), array(float64, 2d, F))':
Rejected as the implementation raised a specific error:
LoweringError: Failed in nopython mode pipeline (step: native lowering)
scipy 0.16+ is required for linear algebra
File "[...]\numba\np\linalg.py", line 582:
def _dot2_codegen(context, builder, sig, args):
<source elided>
return lambda left, right: _impl(left, right)
^
During: lowering "$8call_function.3 = call $2load_deref.0(left, right, func=$2load_deref.0, args=[Var(left, linalg.py:582), Var(right, linalg.py:582)], kws=(), vararg=None, varkwarg=None, target=None)" at [...]\numba\np\linalg.py (582)
raised from [...]\numba\core\errors.py:837
- Of which 2 did not match due to:
Overload in function 'dot_3': File: numba\np\linalg.py: Line 784.
With argument(s): '(array(float64, 2d, C), array(float64, 2d, F))':
Rejected as the implementation raised a specific error:
TypingError: missing a required argument: 'out'
raised from [...]\numba\core\typing\templates.py:784
During: resolving callee type: Function(<function dot at 0x...>)
During: typing of call at [...]\numba\np\arrayobj.py (5932)
File "[...]\numba\np\arrayobj.py", line 5932:
def dot_impl(arr, other):
return np.dot(arr, other)
^
raised from [...]\numba\core\typeinfer.py:1086
- Resolution failure for non-literal arguments:
None
During: resolving callee type: BoundFunction((<class 'numba.core.types.npytypes.Array'>, 'dot') for array(float64, 2d, C))
During: typing of call at [...]\example.py (7)
File "scratch_2.py", line 7:
def matmul_transposed(a: np.ndarray, b: np.ndarray) -> np.ndarray:
<source elided>
"""Return a @ b.T"""
return a.dot(b.transpose())
^
从错误消息中我得出结论,Numba 似乎通过将其布局样式从 C 更改为 Fotran 来转置数组,这是一种廉价的操作,因为它不必物理上更改布局,但它似乎不知道如何乘以 C -style 数组和 Fotrtran 样式数组。
有什么方法可以将这些矩阵相乘吗?最好在进行转置时不复制整个
b
?
这似乎是一个相当普通的操作,所以我很困惑它不起作用。
你的解释并没有离谱:numba 有四个候选者来乘以 C 和 F 布局数组,并给你详细信息,为什么最后没有选择每一个。后两者因缺少参数而被忽略,因此它们显然是用于另一个调用签名。前两者被解雇是因为内部有些东西不起作用:
LoweringError: Failed in nopython mode pipeline (step: native lowering)
scipy 0.16+ is required for linear algebra
虽然第一行非常神秘,但第二行仍然是错误消息的一部分,并给出了很好的提示。手动安装
scipy
就可以了。
旁注:这基本上是 numpy 函数的一行,它应该在单个 CPU 核心上表现得很好,因为 numba 不需要消除太多 Python 开销。当然,这取决于您还想做什么,但不要指望这一单品会获得显着的提升。
我正在做更复杂的计算,但这不起作用。 sassa状态检查
我正在做更复杂的hesgoals计算,但这不起作用。