Numba:将 numpy 数组转换为可哈希对象

问题描述 投票:0回答:1

在 numba

.jit(nopython=True)
函数中,我正在计算数千个 numpy 数组(一维,整数数据类型)并将它们附加到列表中。问题是某些数组看起来相等,但我不需要重复项。所以我需要一种有效的方法来检查新数组是否已存在于列表中。

在Python中可以这样完成:

import numpy as np
import numba as nb

# @nb.jit(nopython=True)
def foo(n):

    uniques = []
    uniques_set = set()

    for _ in range(n):

        arr = np.random.randint(0, 2, 2)
        arr_hashable = make_hashable(arr)

        if not arr_hashable in uniques_set:
            uniques_set.add(arr_hashable)
            uniques.append(arr)

    return uniques

我尝试了两种方法来解决这个问题:

  1. 将数组转换为元组并将元组放入集合中。

    def make_hashable(arr):
        return tuple(arr)
    

    但不幸的是,直接元组构造在 nopython 中不能以这种方式工作 模式。我也尝试过这种方式:

    def make_hashable(arr):
        res = ()
        for n in arr:
            res += (n,)
        return res
    

    以及我能想到的其他类似的解决方法,但它们都失败了 带有 TypeError 的 nopython 模式。

  2. 将数组转换为字符串并将其放入集合中。

    def make_hashable(arr):
        return arr.tostring()
    

    还尝试了所有可能的方法将数组转换为字符串,但看起来像 numba 暂时不支持字符串转换

也许有不同的方法来(有效地)检查数组是否已存在于列表中?我的 numba 版本是 0.44。非常感谢。

python numba
1个回答
0
投票

我有 numba 0.58,但我知道解决你的问题的唯一方法仍然是使用 回调到对象模式来散列数组。像这样:

import numpy as np
import numba as nb

def make_hashable(arr):
    return hash(arr.tobytes())

@nb.jit(nopython=True)
def foo(n):
    uniques = []
    uniques_set = set()
    for _ in range(n):
        arr = np.random.randint(0, 2, 2)
        with nb.objmode(arr_hashable='intp'):
            arr_hashable = make_hashable(arr)

        if arr_hashable not in uniques_set:
            uniques_set.add(arr_hashable)
            uniques.append(arr)
    return uniques

foo(100)
# => [array([0, 0]), array([0, 1]), array([1, 1]), array([1, 0])]
© www.soinside.com 2019 - 2024. All rights reserved.