将数组参数发送到 numba 函数

问题描述 投票:0回答:1

我有以下代码,它将数组作为参数发送到 numba 函数:

import numpy as np
from numba import njit, float64

A = [( 0.0182286178413157, -1.2904019395416308),
 ( 0.5228683581098151,  0.2323207738837293),
 (-0.6056770113345468,  1.5990251249135883),
 (-0.7557841434090988,  1.4641641762952791),
 ( 0.9882455737412416, -1.1838797980930709),
 (-1.2168205368640061,  1.5178083863904257),
 (-0.5566781056044838,  0.2160324328998916),
 ( 0.0671405605855369, -0.4246242749812621),
 ( 0.4806167193998933,  1.0521631181457611),
 ( 0.0563547059786364, -0.8223422191733811)]

A = np.array(A)

@njit(float64(float64[:]))
def distance(a):           
    return a[0]**2 + a[1]**2 + 2*a[0]*a[1]

distance(A)

我无法获取使此代码运行的签名字符串(仅适用于标量参数)。总是出现这个错误:

TypeError: No matching definition for argument type(s) array(float64, 2d, C)
python numba
1个回答
0
投票

你应该这样重写函数:

@njit
def distance(A):
    return A[:, 0]**2 + A[:, 1]**2 + 2*A[:, 0]*A[:, 1]
© www.soinside.com 2019 - 2024. All rights reserved.