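I'm having trouble applying numba to a set of functions whose performance I'm trying to optimize. All of the functions work fine without numba, but when I try to use numba I get a compilation error.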
This is the compilation error I'm struggling with:
Exception occurred:
Type: TypingError
Message: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: nopython frontend)
Cannot unify array(float64, 2d, C) and array(float64, 1d, C) for 'q1.2', defined at .\rotations.py (82)
File "rotations.py", line 82:
def quaternion_mult(q1, qa):
<source elided>
quat_result[:, 0] = (q1[:, 0] * q2[:, 0]) - (q1[:, 1] * q2[:, 1]) - (q1[:, 2] * q2[:, 2]) - (q1[:, 3] * q2[:, 3])
^
During: typing of assignment at .\rotations.py (82)
File "rotations.py", line 82:
def quaternion_mult(q1, qa):
<source elided>
quat_result[:, 0] = (q1[:, 0] * q2[:, 0]) - (q1[:, 1] * q2[:, 1]) - (q1[:, 2] * q2[:, 2]) - (q1[:, 3] * q2[:, 3])
^
During: resolving callee type: type(CPUDispatcher(<function quaternion_mult at 0x00000290EE6FE670>))
During: typing of call at .\rotations.py (102)
During: resolving callee type: type(CPUDispatcher(<function quaternion_mult at 0x00000290EE6FE670>))
During: typing of call at .\rotations.py (102)
File "rotations.py", line 102:
def quaternion_vect_mult(q1, vect_array):
<source elided>
temp = quaternion_mult(q1, q2)
^
Here is the full code of the functions involved:
import numpy as np
from numba import njit


@njit(cache=True)
def quaternion_conjugate_vect(q):
    """
    return the conjugate of a quaternion or an array of quaternions
    """
    return q * np.array([1, -1, -1, -1])
@njit(cache=True)
def quaternion_mult(q1, qa):
    """
    multiply an array of quaternions (Nx4) by a single quaternion.
    qa is always a (Nx4) array of quaternions np.ndarray
    q1 is always a single (1x4) quaternion np.ndarray
    """
    N = max(len(qa), len(q1))
    quat_result = np.zeros((N, 4), dtype=np.float64)
    if qa.ndim == 1:
        q2 = qa.copy().reshape((1, -1))
        # q2 = np.reshape(q1, (1,-1))
    else:
        q2 = qa
    if q1.ndim == 1:
        # q1 = q1.copy().reshape((1, -1))
        q1 = np.reshape(q1, (1, -1))
    quat_result[:, 0] = (q1[:, 0] * q2[:, 0]) - (q1[:, 1] * q2[:, 1]) - (q1[:, 2] * q2[:, 2]) - (q1[:, 3] * q2[:, 3])
    quat_result[:, 1] = (q1[:, 0] * q2[:, 1]) + (q1[:, 1] * q2[:, 0]) + (q1[:, 2] * q2[:, 3]) - (q1[:, 3] * q2[:, 2])
    quat_result[:, 2] = (q1[:, 0] * q2[:, 2]) + (q1[:, 2] * q2[:, 0]) + (q1[:, 3] * q2[:, 1]) - (q1[:, 1] * q2[:, 3])
    quat_result[:, 3] = (q1[:, 0] * q2[:, 3]) + (q1[:, 3] * q2[:, 0]) + (q1[:, 1] * q2[:, 2]) - (q1[:, 2] * q2[:, 1])
    return quat_result
@njit(cache=True)
def quaternion_vect_mult(q1, vect_array):
    """
    Multiplies an array of x,y,z coordinates by a single quaternion q1.
    """
    # q1 is the quaternion which the coordinates will be rotated by.
    # Add initial column of zeros to array
    # N = len(vect_array)
    q2 = np.zeros((len(vect_array), 4), dtype=np.float64)
    q2[:, 1::] = vect_array
    temp = quaternion_mult(q1, q2)
    result = quaternion_mult(temp, quaternion_conjugate_vect(q1))
    return result[:, 1::]
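I don't understand the unification error: I'm broadcasting in the multiplication, so shouldn't the shapes be irrelevant? All the arrays are np.float64, so I specified that as the dtype. The only difference is the shape, but normal NumPy broadcasting should work here, just as it does without numba. (I added some extra parentheses to make sure I was multiplying correctly, but they aren't needed at all.)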
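I thought the problem might be related to creating the storage array with np.zeros. I added that because previously I computed each column separately and then combined them with np.stack.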
My only other idea is that it has to do with the if ... else ..., where I check whether the single quaternion has shape (1, 4) or (4,).
I'm a bit stuck on this. Other similar questions usually seem to involve a type mismatch, such as int vs float or float32 vs float64.
Any help would be greatly appreciated.
For clarity, here is an example that works without numba but fails when it is enabled:
from numba import njit
import numpy as np

quat_single = np.random.random(4)
coord_array = np.random.random([9, 3])
# Note: quat_single = np.random.random([1, 4]) will work with numba
quaternion_vect_mult(quat_single, coord_array)
Out[18]:
array([[ 0.12035005,  1.51894951,  0.26731225],
       [ 1.56889141,  0.56465019,  0.18818138],
       [ 0.58966646,  1.09653585, -0.19548354],
       [ 1.15044012,  1.56034916,  0.73943456],
       [ 0.83003034,  1.80861828,  0.02678796],
       [ 1.15572912,  0.54263501,  0.16206597],
       [ 1.34243762,  1.0802315 , -0.20735991],
       [ 1.5876305 ,  0.70017144,  0.80066164],
       [ 1.20734218,  1.2747372 , -0.47177605]])
With these lines:
temp = quaternion_mult(q1, q2)
result = quaternion_mult(temp, quaternion_conjugate_vect(q1))
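you pass arguments of different dimensionality each time: (1d, 2d) in the first call and (2d, 1d) in the second. quaternion_mult tries to handle both by rebinding q1 (or q2) to a reshaped array inside an if/else, so the same variable is array(float64, 1d, C) on one path and array(float64, 2d, C) on the other. Numba's type inference needs a single type for each variable where those paths merge, so it cannot compile the function and reports "Cannot unify ... for 'q1.2'".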
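The broadcasting itself is fine: in plain NumPy the (4,) and (1, 4) layouts give identical results for this arithmetic, which is why the functions work without numba. What differs is ndim, and Numba treats a 1d array and a 2d array as distinct types. A small NumPy-only illustration (the sample values are arbitrary):

```python
import numpy as np

q_flat = np.array([0.5, 0.1, 0.2, 0.3])               # single quaternion, shape (4,)
q_row = q_flat.reshape((1, -1))                       # the same quaternion, shape (1, 4)
qs = np.arange(12, dtype=np.float64).reshape((3, 4))  # three quaternions, shape (3, 4)

# Indexing the (4,) array gives a scalar; slicing the (1, 4) array gives shape (1,),
# which broadcasts against the (3,) column. Both products have shape (3,).
from_flat = q_flat[0] * qs[:, 0]
from_row = q_row[:, 0] * qs[:, 0]

print(np.array_equal(from_flat, from_row))  # True
print(q_flat.ndim, q_row.ndim)              # 1 2
```

The values match, but to Numba a variable holding q_flat and a variable holding q_row have different types, and one variable cannot be both.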
Create a separate quaternion_mult for each argument type/dimensionality combination you want to support. For example, replace quaternion_mult with:
import numpy as np
from numba import njit


@njit(cache=True)
def quaternion_conjugate_vect(q):
    """
    return the conjugate of a quaternion or an array of quaternions
    """
    return q * np.array([1, -1, -1, -1])
@njit(cache=True)
def quaternion_mult1(q1, qa):
    """
    multiply a single quaternion by an array of quaternions (Nx4).
    qa is always a (Nx4) array of quaternions np.ndarray
    q1 is always a single quaternion np.ndarray of shape (4,)
    """
    # One output row per quaternion in qa. q1 is 1d, so len(q1) is always 4;
    # max(len(qa), len(q1)) would break for fewer than 4 coordinates.
    N = len(qa)
    quat_result = np.zeros((N, 4), dtype=np.float64)
    # if qa.ndim == 1:
    #     q2 = qa.copy().reshape((1, -1))
    #     # q2 = np.reshape(q1, (1,-1))
    # else:
    #     q2 = qa
    q2 = qa
    # if q1.ndim == 1:
    #     # q1 = q1.copy().reshape((1, -1))
    #     q1 = np.reshape(q1, (1, -1))
    # q1 stays 1d, so q1[0] ... q1[3] are scalars broadcast against q2's columns
    quat_result[:, 0] = (
        (q1[0] * q2[:, 0])
        - (q1[1] * q2[:, 1])
        - (q1[2] * q2[:, 2])
        - (q1[3] * q2[:, 3])
    )
    quat_result[:, 1] = (
        (q1[0] * q2[:, 1])
        + (q1[1] * q2[:, 0])
        + (q1[2] * q2[:, 3])
        - (q1[3] * q2[:, 2])
    )
    quat_result[:, 2] = (
        (q1[0] * q2[:, 2])
        + (q1[2] * q2[:, 0])
        + (q1[3] * q2[:, 1])
        - (q1[1] * q2[:, 3])
    )
    quat_result[:, 3] = (
        (q1[0] * q2[:, 3])
        + (q1[3] * q2[:, 0])
        + (q1[1] * q2[:, 2])
        - (q1[2] * q2[:, 1])
    )
    return quat_result
@njit(cache=True)
def quaternion_mult2(q1, qa):
    """
    multiply an array of quaternions q1 (Nx4) by a single quaternion qa of shape (4,).
    """
    # One output row per quaternion in q1 (qa is 1d, so len(qa) is always 4)
    N = len(q1)
    quat_result = np.zeros((N, 4), dtype=np.float64)
    # Copy the single quaternion into a new (1, 4) array so q2[:, k] broadcasts
    # against q1[:, k]; qa itself is never rebound
    q2 = qa.copy().reshape((1, -1))
    quat_result[:, 0] = (
        (q1[:, 0] * q2[:, 0])
        - (q1[:, 1] * q2[:, 1])
        - (q1[:, 2] * q2[:, 2])
        - (q1[:, 3] * q2[:, 3])
    )
    quat_result[:, 1] = (
        (q1[:, 0] * q2[:, 1])
        + (q1[:, 1] * q2[:, 0])
        + (q1[:, 2] * q2[:, 3])
        - (q1[:, 3] * q2[:, 2])
    )
    quat_result[:, 2] = (
        (q1[:, 0] * q2[:, 2])
        + (q1[:, 2] * q2[:, 0])
        + (q1[:, 3] * q2[:, 1])
        - (q1[:, 1] * q2[:, 3])
    )
    quat_result[:, 3] = (
        (q1[:, 0] * q2[:, 3])
        + (q1[:, 3] * q2[:, 0])
        + (q1[:, 1] * q2[:, 2])
        - (q1[:, 2] * q2[:, 1])
    )
    return quat_result
@njit(cache=True)
def quaternion_vect_mult(q1, vect_array):
    """
    Multiplies an array of x,y,z coordinates by a single quaternion q1.
    """
    # q1 is the quaternion which the coordinates will be rotated by.
    # Add initial column of zeros to array
    # N = len(vect_array)
    q2 = np.zeros((len(vect_array), 4), dtype=np.float64)
    q2[:, 1::] = vect_array
    temp = quaternion_mult1(q1, q2)
    result = quaternion_mult2(temp, quaternion_conjugate_vect(q1))
    return result[:, 1::]
Calling it with a 1d quaternion now compiles under numba:
np.random.seed(42)
quat_single = np.random.random(4)
coord_array = np.random.random([9, 3])
print(quaternion_vect_mult(quat_single, coord_array))
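An alternative to splitting the function, sketched here under my own assumptions rather than taken from the answer above: keep a single function, but bind 2d views of the inputs to new names with np.atleast_2d (listed among Numba's supported NumPy functions) instead of rebinding q1/qa. Each variable then has one type, and Numba simply compiles one specialization per argument combination. The names quaternion_mult_any and rotate_coords are hypothetical, and the fallback njit exists only so the sketch runs where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # Numba not installed: fall back to plain Python so the sketch still runs
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]):
            return args[0]
        return lambda func: func


@njit
def quaternion_mult_any(q1, qa):
    """Hamilton product q1 * qa, where each argument has shape (4,) or (N, 4).

    The 2d views are bound to new names (a, b) instead of rebinding q1/qa,
    so each variable has exactly one array type inside the function.
    """
    a = np.atleast_2d(q1)
    b = np.atleast_2d(qa)
    n = max(a.shape[0], b.shape[0])  # rows after broadcasting (1, 4) against (N, 4)
    out = np.empty((n, 4), dtype=np.float64)
    out[:, 0] = a[:, 0] * b[:, 0] - a[:, 1] * b[:, 1] - a[:, 2] * b[:, 2] - a[:, 3] * b[:, 3]
    out[:, 1] = a[:, 0] * b[:, 1] + a[:, 1] * b[:, 0] + a[:, 2] * b[:, 3] - a[:, 3] * b[:, 2]
    out[:, 2] = a[:, 0] * b[:, 2] + a[:, 2] * b[:, 0] + a[:, 3] * b[:, 1] - a[:, 1] * b[:, 3]
    out[:, 3] = a[:, 0] * b[:, 3] + a[:, 3] * b[:, 0] + a[:, 1] * b[:, 2] - a[:, 2] * b[:, 1]
    return out


@njit
def rotate_coords(q, coords):
    """Compute q * v * conj(q) for every row v = (0, x, y, z) of an (N, 3) array."""
    v = np.zeros((coords.shape[0], 4), dtype=np.float64)
    v[:, 1:] = coords
    q_conj = q * np.array([1.0, -1.0, -1.0, -1.0])
    return quaternion_mult_any(quaternion_mult_any(q, v), q_conj)[:, 1:]


# 90 degree rotation about z: (1, 0, 0) should map to (0, 1, 0)
angle = np.pi / 2
q_z90 = np.array([np.cos(angle / 2), 0.0, 0.0, np.sin(angle / 2)])
print(rotate_coords(q_z90, np.array([[1.0, 0.0, 0.0]])))  # close to [[0, 1, 0]]
```

One thing worth keeping in mind when comparing outputs: np.random.random(4) is generally not a unit quaternion, and for a non-unit q the product q v q* also scales every vector by |q|². Normalize q first if a pure rotation is intended.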