The following code compiles and runs correctly:
import numpy as np
from numba import njit
Particle = np.dtype([ ('position', 'f4'), ('velocity', 'f4')])
arr = np.zeros(2, dtype=Particle)
@njit
def f(x):
    x[0]['position'] = x[1]['position'] + x[1]['velocity'] * 0.2 + 1.
f(arr)
However, giving the dtype's fields a higher dimension makes this code fail at compile time (although it works without @njit):
import numpy as np
from numba import njit
Particle = np.dtype([
    ('position', 'f4', (2,)),
    ('velocity', 'f4', (2,))
])
arr = np.zeros(2, dtype=Particle)
@njit
def f(x):
    x[0]['position'] = x[1]['position'] + x[1]['velocity'] * 0.2 + 1.
f(arr)
It fails with the following error:
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function setitem>) found for signature:
>>> setitem(Record(position[type=nestedarray(float32, (2,));offset=0],velocity[type=nestedarray(float32, (2,));offset=8];16;False), Literal[str](position), array(float64, 1d, C))
There are 16 candidate implementations:
- Of which 16 did not match due to:
Overload of function 'setitem': File: <numerous>: Line N/A.
With argument(s): '(Record(position[type=nestedarray(float32, (2,));offset=0],velocity[type=nestedarray(float32, (2,));offset=8];16;False), unicode_type, array(float64, 1d, C))':
No match.
During: typing of staticsetitem at /tmp/ipykernel_21235/2952285515.py (13)
File "../../../../tmp/ipykernel_21235/2952285515.py", line 13:
<source missing, REPL/exec in use?>
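Any ideas on how to remedy the latter problem? I would like to use higher-dimensional dtypes.
You can try using [:] to set the values of the array in place: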
import numpy as np
from numba import njit
Particle = np.dtype([("position", "f4", (2,)), ("velocity", "f4", (2,))])
arr = np.zeros(2, dtype=Particle)
@njit
def f(x):
    pos_0 = x[0]["position"]
    pos_0[:] = x[1]["position"] + x[1]["velocity"] * 0.2 + 1.0
    # x[0]["position"][:] = ... works too
f(arr)
print(arr)
Prints:
[([1., 1.], [0., 0.]) ([0., 0.], [0., 0.])]
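Why this helps, as far as the error message shows: the direct assignment asks Numba for a setitem on the record itself with a float64 array as the value, and no such overload is found. Reading x[0]["position"] first gives an array over the record's storage, and slice assignment into that array writes through to the original data. The same view behaviour can be checked in plain NumPy, without Numba. The sketch below is only an illustration: the step function and the sample velocity values are made up for this example and are not part of the question.

```python
import numpy as np

Particle = np.dtype([("position", "f4", (2,)), ("velocity", "f4", (2,))])
arr = np.zeros(2, dtype=Particle)

# Hypothetical sample data: give particle 1 a non-zero velocity.
arr["velocity"][1] = [10.0, 20.0]


def step(x):
    # x[0]["position"] is an array over the record's memory, not a copy,
    # so writing through [:] updates x itself.
    pos_0 = x[0]["position"]
    pos_0[:] = x[1]["position"] + x[1]["velocity"] * 0.2 + 1.0


step(arr)
print(arr[0]["position"])        # updated in place
print(arr["position"].dtype)     # the field keeps its float32 dtype
```

Because the assignment goes through a slice, the right-hand side is cast element by element into the float32 field, so a float64 result does not need to match the field's dtype exactly.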