查找实数数组中最近元素的最快方法

问题描述 投票:2回答:3

对于每个元素的给定实数数组,找到小于当前元素的元素数量不超过0.5并写入新数组。

例如:

原始阵列:

[0.1, 0.7, 0.8, 0.85, 0.9, 1.5, 1.7]

结果数组:

[0,   0,   1,   2,    3,   0,   1]

解决这个问题的算法和方法是什么?

重要的是,仅在负方向上选择点的邻域,这使得不可能使用KdtreeBalltree算法。

我的所有问题都是here,我尝试使用代码。

python search tree binary-search-tree kdtree
3个回答
0
投票

虽然下面的方法使用简单的逻辑并且易于编写,但它很慢。我们可以通过使用装饰的Numba功能加快速度。这将加速简单的循环任务,使其接近汇编语言速度。

pip install numba安装Numba。

from numba import jit
import numpy as np

# Create a numpy array of length 10000 with float values between 0 and 10
random_values = np.random.uniform(0.0,10.0,size=(100*100,))

@jit(nopython=True, nogil=True)
def find_nearest(input):
  result = []
  for e in input:
    counter = 0
    for j in input:
      if j >= (e-0.5) and j < e:
        counter += 1
    result.append(counter)
  return result

result = find_nearest(random_values)

请注意,为测试用例返回了预期结果:

test = [0.1, 0.7, 0.8, 0.85, 0.9, 1.5, 1.7]
result = find_nearest(test)
print result

返回:

[0, 0, 1, 2, 3, 0, 1]

0
投票

这将解决您的具体任务。

def find_nearest_element(original_array):
    result_array = []
    for e in original_array:
        result_array.append(len(original_array[(e-0.5 < original_array) & (e > original_array)]))
    return result_array

original_array = np.array([0.1, 0.7, 0.8, 0.85, 0.9, 1.5, 1.7])
print(find_nearest_element(original_array))

输出:

[0, 0, 1, 2, 3, 0, 1]

编辑:使用掩码比使用numba的较小阵列(ca.len 10000)的版本快得多。对于更大的阵列,使用Numba的版本更快。所以这取决于你想要进步的数组大小。

一些运行时比较(以秒为单位):

For smaller arrays(size=250):
Using Numba 0.2569999694824219
Using mask 0.0350041389465332
For bigger arrays(size=40000):
Using Numba 1.4619991779327393
Using mask 4.280000686645508

我的设备上的收支平衡点大约是10000(大约需要0.33秒)。


0
投票

对于有序数组,这个问题很容易解决。您必须向后搜索并计算所有大于实际数字半径的数字。如果不再满足该条件,则可以突破内循环(这节省了大量时间)。

import numpy as np
from scipy import spatial
import numba as nb

@nb.njit(parallel=True)
def get_counts_2(Points_sorted,ind,r):
  counts=np.zeros(Points_sorted.shape[0],dtype=np.int64)
  for i in nb.prange(0,Points_sorted.shape[0]):
    count=0
    for j in range(i-1,0,-1):
      if (Points_sorted[i]-r<Points_sorted[j]):
        count+=1
      else:
        break
    counts[ind[i]]=count
  return counts

计时

r=0.001
Points=np.random.rand(1_000_000)

t1=time.time()
ind=np.argsort(Points)
Points_sorted=Points[ind]
counts=get_counts_2(Points_sorted,ind,r)
print(time.time()-t1)
#0.29s
© www.soinside.com 2019 - 2024. All rights reserved.