我在下面有一个函数,它接收一个浮点数数组和一个离散整数数组。对于所有的浮点数,我希望它们四舍五入到列表中最接近的整数。
以下函数完美运行,其中 sHatV 是一个包含 10,000 个浮点数的数组,而 possible_locations 是一个包含 5 个整数的数组:
binnedV = [min(possible_locations, key=lambda x:abs(x-bv)) for bv in sHatV]
因为这个函数将被调用数千次,所以我尝试使用
@numba.njit
装饰器来最小化计算时间。
我考虑过在我的“numbafied”函数中使用
np.digitize
,但它会将超出范围的值四舍五入为零。我希望所有内容都被合并到可能位置的值之一。
总的来说,我需要编写一个 numba 兼容函数,它获取第一个长度为 N 的数组中的每个值,在数组 2 中找到最接近它的值,并返回最接近的值,最终形成一个长度为 N 的数组,其中包含分箱值.
感谢任何帮助!
这是一个运行速度更快的版本,并且可能更“可计算”,因为它使用 numpy 函数而不是列表理解的隐式 for 循环:
import numpy as np
sHatV = [0.33, 4.18, 2.69]
possible_locations = np.array([0, 1, 2, 3, 4, 5])
diff_matrix = np.subtract.outer(sHatV, possible_locations)
idx = np.abs(diff_matrix).argmin(axis=1)
result = possible_locations[idx]
print(result)
# output: [0 4 3]
这里的思路是计算
sHatv
和possible_locations
之间的差异矩阵。在这个特定的例子中,矩阵是:
array([[ 0.33, -0.67, -1.67, -2.67, -3.67, -4.67],
[ 4.18, 3.18, 2.18, 1.18, 0.18, -0.82],
[ 2.69, 1.69, 0.69, -0.31, -1.31, -2.31]])
然后,使用
np.abs( ... ).argmin(axis=1)
,我们找到绝对差异最小的每一行的索引。如果我们用这些索引索引原始的possible_locations
数组,我们就得到了答案。
比较运行时间:
使用列表理解
def f(possible_locations, sHatV):
return [min(possible_locations, key=lambda x:abs(x-bv)) for bv in sHatV]
def test_f():
possible_locations = np.array([0, 1, 2, 3, 4, 5])
sHatV = np.random.uniform(0.1, 4.9, size=10_000)
f(possible_locations, sHatV)
%timeit test_f()
# 187 ms ± 7.96 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
使用差异矩阵
def g(possible_locations, sHatV):
return possible_locations[np.abs(np.subtract.outer(sHatV, bins)).argmin(axis=1)]
def test_g():
possible_locations = np.array([0, 1, 2, 3, 4, 5])
sHatV = np.random.uniform(0.1, 4.9, size=10_000)
g(possible_locations, sHatV)
%timeit test_g()
# 556 µs ± 24.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
np.searchSorted()
函数。 np.digitize()
本身是根据np.searchSorted()
实现的。例如,
import numpy as np
offset = 1e-8
indices = np.searchsorted(possible_locations, sHatV - offset)
return possible_locations[np.clip(indices, 0, len(int) - 1)]
我建议为此坚持使用 numpy。
digitize
功能接近您的需要,但需要进行一些修改:
If values in `x` are beyond the bounds of `bins`, 0 or ``len(bins)`` is returned as appropriate.
举个例子:
import numpy as np
sHatV = np.array([-99, 1.4999, 1.5, 3.1, 3.9, 99.5, 1000])
bins = np.arange(0,101)
def custom_round(arr, bins):
bin_centers = (bins[:-1] + bins[1:])/2
idx = np.digitize(arr, bin_centers)
result = bins[idx]
return result
assert np.all(custom_round(sHatV, bins) == np.array([0, 1, 2, 3, 4, 100, 100]))
现在我最喜欢的部分是:numpy 在这方面有多快?我不会做缩放,我们只会选择大数组:
sHatV = 10009*np.random.random(int(1e6))
bins = np.arange(10000)
%timeit custom_round(sHatV, bins)
# on a laptop: 100 ms ± 2.49 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)