使用 Numpy 和 Numba 将一组值分箱到离散集中最接近的值

问题描述 投票:0回答:3

我在下面有一个函数,它接收一个浮点数数组和一个离散整数数组。对于所有的浮点数,我希望它们四舍五入到列表中最接近的整数。

以下函数完美运行,其中 sHatV 是一个包含 10,000 个浮点数的数组,而 possible_locations 是一个包含 5 个整数的数组:

binnedV = [min(possible_locations, key=lambda x:abs(x-bv)) for bv in sHatV]

因为这个函数将被调用数千次,所以我尝试使用

@numba.njit
装饰器来最小化计算时间。

我考虑过在我的“numbafied”函数中使用

np.digitize
,但它会将超出范围的值四舍五入为零。我希望所有内容都被合并到可能位置的值之一。

总的来说,我需要编写一个 numba 兼容函数,它获取第一个长度为 N 的数组中的每个值,在数组 2 中找到最接近它的值,并返回最接近的值,最终形成一个长度为 N 的数组,其中包含分箱值.

感谢任何帮助!

python numpy numba
3个回答
1
投票

这是一个运行速度更快的版本,并且可能更“可计算”,因为它使用 numpy 函数而不是列表理解的隐式 for 循环:

import numpy as np

sHatV = [0.33, 4.18, 2.69]
possible_locations = np.array([0, 1, 2, 3, 4, 5])

diff_matrix = np.subtract.outer(sHatV, possible_locations)
idx = np.abs(diff_matrix).argmin(axis=1)
result = possible_locations[idx]

print(result)
# output: [0 4 3]

这里的思路是计算

sHatv
possible_locations
之间的差异矩阵。在这个特定的例子中,矩阵是:

array([[ 0.33, -0.67, -1.67, -2.67, -3.67, -4.67],
       [ 4.18,  3.18,  2.18,  1.18,  0.18, -0.82],
       [ 2.69,  1.69,  0.69, -0.31, -1.31, -2.31]])

然后,使用

np.abs( ... ).argmin(axis=1)
,我们找到绝对差异最小的每一行的索引。如果我们用这些索引索引原始的
possible_locations
数组,我们就得到了答案。

比较运行时间:

使用列表理解

def f(possible_locations, sHatV):
    return [min(possible_locations, key=lambda x:abs(x-bv)) for bv in sHatV]


def test_f():
    possible_locations = np.array([0, 1, 2, 3, 4, 5])
    sHatV = np.random.uniform(0.1, 4.9, size=10_000)
    f(possible_locations, sHatV)


%timeit test_f()
# 187 ms ± 7.96 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

使用差异矩阵

def g(possible_locations, sHatV):
    return possible_locations[np.abs(np.subtract.outer(sHatV, bins)).argmin(axis=1)]


def test_g():
    possible_locations = np.array([0, 1, 2, 3, 4, 5])
    sHatV = np.random.uniform(0.1, 4.9, size=10_000)
    g(possible_locations, sHatV)

%timeit test_g()
# 556 µs ± 24.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

1
投票

为此,您可以使用 numpy 的

np.searchSorted()
函数。
np.digitize()
本身是根据
np.searchSorted()
实现的。例如,

import numpy as np

offset = 1e-8
indices = np.searchsorted(possible_locations, sHatV - offset)
return possible_locations[np.clip(indices, 0, len(int) - 1)]

1
投票

我建议为此坚持使用 numpy。

digitize
功能接近您的需要,但需要进行一些修改:

  • 实现舍入逻辑而不是 floor/ceil
  • 解决端点问题。文档说:
    If values in `x` are beyond the bounds of `bins`, 0 or ``len(bins)`` is returned as appropriate.

举个例子:

import numpy as np
sHatV = np.array([-99, 1.4999, 1.5, 3.1, 3.9, 99.5, 1000])
bins = np.arange(0,101)

def custom_round(arr, bins):
    bin_centers = (bins[:-1] + bins[1:])/2 
    idx = np.digitize(arr, bin_centers)
    result = bins[idx]
    return result

assert np.all(custom_round(sHatV, bins) == np.array([0, 1, 2, 3, 4, 100, 100]))

现在我最喜欢的部分是:numpy 在这方面有多快?我不会做缩放,我们只会选择大数组:

sHatV = 10009*np.random.random(int(1e6))
bins = np.arange(10000)

%timeit custom_round(sHatV, bins)
# on a laptop: 100 ms ± 2.49 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
© www.soinside.com 2019 - 2024. All rights reserved.