使用广播将单个向量与向量数组相乘时出现 Numba 输入错误

问题描述 投票:0回答:1

我在将

numba
应用于我试图优化性能的一组函数时遇到问题。所有函数在没有
numba
的情况下都可以正常工作,但是当我尝试使用
numba
时出现编译错误。

这是我正在努力解决的编译错误:

Exception occurred:
Type: TypingError
Message: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: nopython frontend)
Cannot unify array(float64, 2d, C) and array(float64, 1d, C) for 'q1.2', defined at .\rotations.py (82)

File "rotations.py", line 82:
def quaternion_mult(q1, qa):
    <source elided>

    quat_result[:, 0] = (q1[:, 0] * q2[:, 0]) - (q1[:, 1] * q2[:, 1]) - (q1[:, 2] * q2[:, 2]) - (q1[:, 3] * q2[:, 3])
    ^

During: typing of assignment at .\rotations.py (82)

File "rotations.py", line 82:
def quaternion_mult(q1, qa):
    <source elided>

    quat_result[:, 0] = (q1[:, 0] * q2[:, 0]) - (q1[:, 1] * q2[:, 1]) - (q1[:, 2] * q2[:, 2]) - (q1[:, 3] * q2[:, 3])
    ^

During: resolving callee type: type(CPUDispatcher(<function quaternion_mult at 0x00000290EE6FE670>))
During: typing of call at .\rotations.py (102)

During: resolving callee type: type(CPUDispatcher(<function quaternion_mult at 0x00000290EE6FE670>))
During: typing of call at .\rotations.py (102)


File "rotations.py", line 102:
def quaternion_vect_mult(q1, vect_array):
    <source elided>

    temp = quaternion_mult(q1, q2)
    ^

这是相应函数的完整代码:


@njit(cache=True)
def quaternion_conjugate_vect(q):
    """
    return the conjugate of a quaternion or an array of quaternions
    """
    return q * np.array([1, -1, -1, -1])


@njit(cache=True)
def quaternion_mult(q1, qa):
    """
    multiply an array of quaternions (Nx4) by a single quaternion.

    qa is always a (Nx4) array of quaternions np.ndarray
    q1 is always a single (1x4) quaternion np.ndarray

    """
    N = max(len(qa), len(q1))
    quat_result = np.zeros((N, 4), dtype=np.float64)

    if qa.ndim == 1:
        q2 = qa.copy().reshape((1, -1))
        # q2 = np.reshape(q1, (1,-1))
    else:
        q2 = qa

    if q1.ndim == 1:
        # q1 = q1.copy().reshape((1, -1))
        q1 = np.reshape(q1, (1, -1))

    quat_result[:, 0] = (q1[:, 0] * q2[:, 0]) - (q1[:, 1] * q2[:, 1]) - (q1[:, 2] * q2[:, 2]) - (q1[:, 3] * q2[:, 3])
    quat_result[:, 1] = (q1[:, 0] * q2[:, 1]) + (q1[:, 1] * q2[:, 0]) + (q1[:, 2] * q2[:, 3]) - (q1[:, 3] * q2[:, 2])
    quat_result[:, 2] = (q1[:, 0] * q2[:, 2]) + (q1[:, 2] * q2[:, 0]) + (q1[:, 3] * q2[:, 1]) - (q1[:, 1] * q2[:, 3])
    quat_result[:, 3] = (q1[:, 0] * q2[:, 3]) + (q1[:, 3] * q2[:, 0]) + (q1[:, 1] * q2[:, 2]) - (q1[:, 2] * q2[:, 1])

    return quat_result


@njit(cache=True)
def quaternion_vect_mult(q1, vect_array):
    """
    Multiplies an array of x,y,z coordinates by a single quaternion q1.
    """
    # q1 is the quaternion which the coordinates will be rotated by.

    # Add initial column of zeros to array
    # N = len(vect_array)
    q2 = np.zeros((len(vect_array), 4), dtype=np.float64)
    q2[:, 1::] = vect_array

    temp = quaternion_mult(q1, q2)
    result = quaternion_mult(temp, quaternion_conjugate_vect(q1))

    return result[:, 1::]

我不明白统一错误,因为我在乘法中广播,所以形状应该是无关的?所有数组都是

np.float64
所以我将其指定为类型。唯一的区别是形状,但正常的
numpy
广播应该在这里工作,就像没有
numba
一样。 (我添加了一些额外的括号以确保我正确地相乘,但根本不需要它们。)

我认为问题与

np.zeros
存储数组的创建有关,我已经添加了这个,因为之前我单独计算了每一列,然后与
np.stack
组合。

我唯一的其他想法是它与

if ... else...
有关,我在其中检查单个四元数是否为
shape
(1,4)
而不是
(,4)

我对此有点困惑,其他类似的问题通常似乎涉及一些类型差异,例如

int
float
float32
float64

如有任何帮助,我们将不胜感激。

为了清楚起见,下面是一个在没有

numba
的情况下工作但在启用它时失败的示例:

from numba import njit
import numpy as np

quat_single = np.random.random(,4)
coord_array = np.random.random([9,3])

Note: quat_single = np.random.random([1,4]) will work with `numba`


quaternion_vect_mult(quat_single, coord_array)
Out[18]: 
array([[ 0.12035005,  1.51894951,  0.26731225],
       [ 1.56889141,  0.56465019,  0.18818138],
       [ 0.58966646,  1.09653585, -0.19548354],
       [ 1.15044012,  1.56034916,  0.73943456],
       [ 0.83003034,  1.80861828,  0.02678796],
       [ 1.15572912,  0.54263501,  0.16206597],
       [ 1.34243762,  1.0802315 , -0.20735991],
       [ 1.5876305 ,  0.70017144,  0.80066164],
       [ 1.20734218,  1.2747372 , -0.47177605]])

python arrays numpy numba array-broadcasting
1个回答
0
投票

用这些行:

    temp = quaternion_mult(q1, q2)
    result = quaternion_mult(temp, quaternion_conjugate_vect(q1))

你每次都给出不同的参数类型,所以 numba 很困惑如何编译这个函数。

为您想要支持的每个参数类型/维度分别创建 

quaternion_mult

,例如:

quaternion_mult

与:

@njit(cache=True) def quaternion_conjugate_vect(q): """ return the conjugate of a quaternion or an array of quaternions """ return q * np.array([1, -1, -1, -1]) @njit(cache=True) def quaternion_mult1(q1, qa): """ multiply an array of quaternions (Nx4) by a single quaternion. qa is always a (Nx4) array of quaternions np.ndarray q1 is always a single (1x4) quaternion np.ndarray """ N = max(len(qa), len(q1)) quat_result = np.zeros((N, 4), dtype=np.float64) # if qa.ndim == 1: # q2 = qa.copy().reshape((1, -1)) # # q2 = np.reshape(q1, (1,-1)) # else: # q2 = qa q2 = qa # if q1.ndim == 1: # # q1 = q1.copy().reshape((1, -1)) # q1 = np.reshape(q1, (1, -1)) quat_result[:, 0] = ( (q1[0] * q2[:, 0]) - (q1[1] * q2[:, 1]) - (q1[2] * q2[:, 2]) - (q1[3] * q2[:, 3]) ) quat_result[:, 1] = ( (q1[0] * q2[:, 1]) + (q1[1] * q2[:, 0]) + (q1[2] * q2[:, 3]) - (q1[3] * q2[:, 2]) ) quat_result[:, 2] = ( (q1[0] * q2[:, 2]) + (q1[2] * q2[:, 0]) + (q1[3] * q2[:, 1]) - (q1[1] * q2[:, 3]) ) quat_result[:, 3] = ( (q1[0] * q2[:, 3]) + (q1[3] * q2[:, 0]) + (q1[1] * q2[:, 2]) - (q1[2] * q2[:, 1]) ) return quat_result @njit(cache=True) def quaternion_mult2(q1, qa): N = max(len(qa), len(q1)) quat_result = np.zeros((N, 4), dtype=np.float64) q2 = qa.copy().reshape((1, -1)) quat_result[:, 0] = ( (q1[:, 0] * q2[:, 0]) - (q1[:, 1] * q2[:, 1]) - (q1[:, 2] * q2[:, 2]) - (q1[:, 3] * q2[:, 3]) ) quat_result[:, 1] = ( (q1[:, 0] * q2[:, 1]) + (q1[:, 1] * q2[:, 0]) + (q1[:, 2] * q2[:, 3]) - (q1[:, 3] * q2[:, 2]) ) quat_result[:, 2] = ( (q1[:, 0] * q2[:, 2]) + (q1[:, 2] * q2[:, 0]) + (q1[:, 3] * q2[:, 1]) - (q1[:, 1] * q2[:, 3]) ) quat_result[:, 3] = ( (q1[:, 0] * q2[:, 3]) + (q1[:, 3] * q2[:, 0]) + (q1[:, 1] * q2[:, 2]) - (q1[:, 2] * q2[:, 1]) ) return quat_result @njit(cache=True) def quaternion_vect_mult(q1, vect_array): """ Multiplies an array of x,y,z coordinates by a single quaternion q1. """ # q1 is the quaternion which the coordinates will be rotated by. # Add initial column of zeros to array # N = len(vect_array) q2 = np.zeros((len(vect_array), 4), dtype=np.float64) q2[:, 1::] = vect_array temp = quaternion_mult1(q1, q2) result = quaternion_mult2(temp, quaternion_conjugate_vect(q1)) return result[:, 1::]

打印:

np.random.seed(42) quat_single = np.random.random(4) coord_array = np.random.random([9, 3]) print(quaternion_vect_mult(quat_single, coord_array))

根据我的“基准”,它应该比非 jitted 版本快约 30-40 倍。

© www.soinside.com 2019 - 2024. All rights reserved.