高维结构化 numpy 数据类型上的 numba 类型错误

问题描述 投票:0回答:1

以下代码可以正确编译并执行:

import numpy as np
from numba import njit

Particle = np.dtype([ ('position', 'f4'), ('velocity', 'f4')])

arr = np.zeros(2, dtype=Particle)

@njit
def f(x):
    x[0]['position'] = x[1]['position'] + x[1]['velocity'] * 0.2 + 1.
    
f(arr)

但是,使数据类型具有更高的维度会导致此代码在编译时失败(但无需

@njit
也能工作):

import numpy as np
from numba import njit

Particle = np.dtype([
            ('position', 'f4', (2,)),
            ('velocity', 'f4', (2,))
          ])

arr = np.zeros(2, dtype=Particle)

@njit
def f(x):
    x[0]['position'] = x[1]['position'] + x[1]['velocity'] * 0.2 + 1.
    
f(arr)

出现以下错误:

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function setitem>) found for signature:
 
 >>> setitem(Record(position[type=nestedarray(float32, (2,));offset=0],velocity[type=nestedarray(float32, (2,));offset=8];16;False), Literal[str](position), array(float64, 1d, C))
 
There are 16 candidate implementations:
    - Of which 16 did not match due to:
    Overload of function 'setitem': File: <numerous>: Line N/A.
      With argument(s): '(Record(position[type=nestedarray(float32, (2,));offset=0],velocity[type=nestedarray(float32, (2,));offset=8];16;False), unicode_type, array(float64, 1d, C))':
     No match.

During: typing of staticsetitem at /tmp/ipykernel_21235/2952285515.py (13)

File "../../../../tmp/ipykernel_21235/2952285515.py", line 13:
<source missing, REPL/exec in use?>

对于如何补救后一个问题有什么想法吗?我想使用更高维度的数据类型。

python numba
1个回答
0
投票

您可以尝试使用

[:]
来设置数组的值:

import numpy as np
from numba import njit

Particle = np.dtype([("position", "f4", (2,)), ("velocity", "f4", (2,))])

arr = np.zeros(2, dtype=Particle)


@njit
def f(x):
    pos_0 = x[0]["position"]
    pos_0[:] = x[1]["position"] + x[1]["velocity"] * 0.2 + 1.0

    #x[0]["position"][:] = ... works too

f(arr)
print(arr)

打印:

[([1., 1.], [0., 0.]) ([0., 0.], [0., 0.])]
© www.soinside.com 2019 - 2024. All rights reserved.