使用 Numba JIT 与转置 NumPy 数组进行矩阵乘法不起作用

问题描述 投票:0回答:3

环境

  • 操作系统:Windows 10
  • Python版本:3.10
  • Numba 版本:0.57.0
  • NumPy 版本:1.24.3

示例

import numpy as np
from numba import njit

@njit
def matmul_transposed(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # return a @ b.T  # also tried this with a similar result, np.matmul seems to be unsupported by Numba
    return a.dot(b.transpose())

matmul_transposed(np.array([[1.0, 1.0]]), np.array([[1.0, 1.0]]))

错误

上面的例子引发了一个错误

- Resolution failure for literal arguments:
No implementation of function Function(<function array_dot at 0x...>) found for signature:
 >>> array_dot(array(float64, 2d, C), array(float64, 2d, F))
There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload in function 'array_dot': File: numba\np\arrayobj.py: Line 5929.
    With argument(s): '(array(float64, 2d, C), array(float64, 2d, F))':
   Rejected as the implementation raised a specific error:
     TypingError: Failed in nopython mode pipeline (step: nopython frontend)
   No implementation of function Function(<function dot at 0x...>) found for signature:
    >>> dot(array(float64, 2d, C), array(float64, 2d, F))
   There are 4 candidate implementations:
         - Of which 2 did not match due to:
         Overload in function 'dot_2': File: numba\np\linalg.py: Line 525.
           With argument(s): '(array(float64, 2d, C), array(float64, 2d, F))':
          Rejected as the implementation raised a specific error:
            LoweringError: Failed in nopython mode pipeline (step: native lowering)
          scipy 0.16+ is required for linear algebra
          
          File "[...]\numba\np\linalg.py", line 582:
                      def _dot2_codegen(context, builder, sig, args):
                          <source elided>
                  return lambda left, right: _impl(left, right)
                  ^
          
          During: lowering "$8call_function.3 = call $2load_deref.0(left, right, func=$2load_deref.0, args=[Var(left, linalg.py:582), Var(right, linalg.py:582)], kws=(), vararg=None, varkwarg=None, target=None)" at [...]\numba\np\linalg.py (582)
     raised from [...]\numba\core\errors.py:837
         - Of which 2 did not match due to:
         Overload in function 'dot_3': File: numba\np\linalg.py: Line 784.
           With argument(s): '(array(float64, 2d, C), array(float64, 2d, F))':
          Rejected as the implementation raised a specific error:
            TypingError: missing a required argument: 'out'
     raised from [...]\numba\core\typing\templates.py:784
   
   During: resolving callee type: Function(<function dot at 0x...>)
   During: typing of call at [...]\numba\np\arrayobj.py (5932)
   
   
   File "[...]\numba\np\arrayobj.py", line 5932:
       def dot_impl(arr, other):
           return np.dot(arr, other)
           ^
  raised from [...]\numba\core\typeinfer.py:1086
- Resolution failure for non-literal arguments:
None
During: resolving callee type: BoundFunction((<class 'numba.core.types.npytypes.Array'>, 'dot') for array(float64, 2d, C))
During: typing of call at [...]\example.py (7)
File "scratch_2.py", line 7:
def matmul_transposed(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    <source elided>
    """Return a @ b.T"""
    return a.dot(b.transpose())
    ^

解读

从错误消息中我得出结论,Numba 似乎通过将其布局样式从 C 更改为 Fotran 来转置数组,这是一种廉价的操作,因为它不必物理上更改布局,但它似乎不知道如何乘以 C -style 数组和 Fotrtran 样式数组。

问题

有什么方法可以将这些矩阵相乘吗?最好在进行转置时不复制整个

b

这似乎是一个相当普通的操作,所以我很困惑它不起作用。

python numpy matrix matrix-multiplication numba
3个回答
1
投票

你的解释并没有离谱:numba 有四个候选者来乘以 C 和 F 布局数组,并给你详细信息,为什么最后没有选择每一个。后两者因缺少参数而被忽略,因此它们显然是用于另一个调用签名。前两者被解雇是因为内部有些东西不起作用:

LoweringError: Failed in nopython mode pipeline (step: native lowering)
          scipy 0.16+ is required for linear algebra

虽然第一行非常神秘,但第二行仍然是错误消息的一部分,并给出了很好的提示。手动安装

scipy
就可以了。

旁注:这基本上是 numpy 函数的一行,它应该在单个 CPU 核心上表现得很好,因为 numba 不需要消除太多 Python 开销。当然,这取决于您还想做什么,但不要指望这一单品会获得显着的提升。


0
投票

我正在做更复杂的计算,但这不起作用。 sassa状态检查


-1
投票

我正在做更复杂的hesgoals计算,但这不起作用。

© www.soinside.com 2019 - 2024. All rights reserved.