记录类型为字符串的numba jitclass

问题描述 投票:0回答:1

记录中的 v3 字段是字符串类型。下面的代码无法运行,执行时会报错。

import numpy as np
import pandas as pd
from numba.experimental import jitclass
from numba import types
import os

# Set the NUMBA_VERBOSE environment variable for this process.
os.environ['NUMBA_VERBOSE'] = '1'

# ----- BEGINNING OF THE MODIFIED PART ----- #
# Hand-written numba record type describing one element of the structured
# array: three fields, each declared with an explicit byte offset (0, 8, 16).
recordType = types.Record([
    ('v', {'type': types.int64, 'offset': 0, 'alignment': None, 'title': None}),
    ('v2', {'type': types.float64, 'offset': 8, 'alignment': None, 'title': None}),
    # NOTE: this is the line reported in the traceback that follows the
    # listing: numba.core.types has no attribute 'bytes' (AttributeError).
    ('v3', {'type': types.bytes, 'offset': 16, 'alignment': None, 'title': None})
], 32, False)  # 32 = declared record size; False is presumably the `aligned` flag -- TODO confirm
# jitclass member specification: a single member `data`, typed as an array
# of `recordType` elements (dimension argument 1, layout argument 'C').
spec = [
    ('data', types.Array(recordType, 1, 'C', False))
]
# ----- END OF THE MODIFIED PART ----- #

# jitclass wrapper around the structured array; member types come from `spec`.
@jitclass(spec)
class Test:
    # Store the structured array passed in by the caller.
    def __init__(self, data):
        self.data = data

    # Read each field of the stored array by name and print it.
    def loop(self):
        # All three fields are fetched before anything is printed.
        v = self.data['v']
        v2 = self.data['v2']
        v3 = self.data['v3']
        print("Inside loop:")
        print("v:", v)
        print("v2:", v2)
        print("v3:", v3)

# Create a dictionary with the data
# (three rows; the 'v3' column holds Python strings)
data = {'v': [1, 2, 3], 'v2': [1.0, 2.0, 3.0], 'v3': ['a', 'b', 'c']}

# Create the DataFrame
df = pd.DataFrame(data)

# Define the structured array dtype
# (field names match the names used in recordType above)
dtype = np.dtype([
    ('v', np.int64),
    ('v2', np.float64),
    ('v3', 'S10')  # Byte string with maximum length of 10 characters
])

# Show the records pandas produces, before conversion to `dtype`.
print(df.to_records(index=False))

# Create the structured array
# (the records are rebuilt as a list and converted to the dtype defined above)
data_array = np.array(list(df.to_records(index=False)), dtype=dtype)

print("Original data array:")
print(data_array)

# Create an instance of the Test class
# and print the array's fields from inside the jitclass method.
test = Test(data_array)
test.loop()

错误:

/home/totaljj/miniconda3/bin/conda run -n bt --no-capture-output python /home/totaljj/bt_lite_strategies/test/test_units/test_numba_obj.py 
Traceback (most recent call last):
  File "/home/totaljj/bt_lite_strategies/test/test_units/test_numba_obj.py", line 13, in <module>
    ('v3', {'type': types.bytes, 'offset': 16, 'alignment': None, 'title': None})
AttributeError: module 'numba.core.types' has no attribute 'bytes'
ERROR conda.cli.main_run:execute(124): `conda run python /home/totaljj/bt_lite_strategies/test/test_units/test_numba_obj.py` failed. (See above for error)

Process finished with exit code 1,
python numba jit
1个回答
0
投票

Numba 0.57.1、0.58.1 和 0.59.1 中都没有 `types.bytes` 这个类型。针对您的情况,应当使用 `types.CharSeq(10)`(对应 NumPy 的 `S10` 类型)。

此外,记录的总大小也写错了:它应该是 26 而不是 32,因为字符串字段占 10 个字节,另外两个字段各占 8 个字节(没有对齐),即 8 + 8 + 10 = 26。

这是修改后的代码:

import numpy as np
import pandas as pd
from numba.experimental import jitclass
from numba import types
import os

os.environ['NUMBA_VERBOSE'] = '1'

# ----- BEGINNING OF THE MODIFIED PART ----- #
# Record layout matching the NumPy structured dtype defined further down:
# int64 at offset 0, float64 at offset 8, and a 10-byte character sequence
# at offset 16.
recordType = types.Record([
    ('v', {'type': types.int64, 'offset': 0, 'alignment': None, 'title': None}),
    ('v2', {'type': types.float64, 'offset': 8, 'alignment': None, 'title': None}),
    # types.CharSeq(10) is the numba counterpart of NumPy's 'S10';
    # numba has no `types.bytes`.
    ('v3', {'type': types.CharSeq(10), 'offset': 16, 'alignment': None, 'title': None})
], 26, False)  # 8 + 8 + 10 = 26 bytes per record, no alignment padding
spec = [
    ('data', types.Array(recordType, 1, 'C', False))
]
# ----- END OF THE MODIFIED PART ----- #


# jitclass wrapper around the structured array; member types come from `spec`.
@jitclass(spec)
class Test:
    # Store the structured array passed in by the caller.
    def __init__(self, data):
        self.data = data

    # Read each field of the stored array by name and print it.
    def loop(self):
        v = self.data['v']
        v2 = self.data['v2']
        v3 = self.data['v3']
        print("Inside loop:")
        print("v:", v)
        print("v2:", v2)
        print("v3:", v3)


# Create a dictionary with the data
data = {'v': [1, 2, 3], 'v2': [1.0, 2.0, 3.0], 'v3': ['a', 'b', 'c']}

# Create the DataFrame
df = pd.DataFrame(data)

# Define the structured array dtype
dtype = np.dtype([
    ('v', np.int64),
    ('v2', np.float64),
    ('v3', 'S10')  # Byte string with maximum length of 10 characters
])

print(df.to_records(index=False))

# Create the structured array
data_array = np.array(list(df.to_records(index=False)), dtype=dtype)

print("Original data array:")
print(data_array)

# Create an instance of the Test class
test = Test(data_array)
test.loop()
    
© www.soinside.com 2019 - 2024. All rights reserved.