对于每个元素的给定实数数组,找到小于当前元素的元素数量不超过0.5
并写入新数组。
例如:
原始阵列:
[0.1, 0.7, 0.8, 0.85, 0.9, 1.5, 1.7]
结果数组:
[0, 0, 1, 2, 3, 0, 1]
解决这个问题的算法和方法是什么?
重要的是,仅在负方向上选择点的邻域,这使得不可能使用Kdtree
或Balltree
算法。
我的所有问题都是here,我尝试使用代码。
虽然下面的方法使用简单的逻辑并且易于编写,但它很慢。我们可以通过使用装饰的Numba功能加快速度。这将加速简单的循环任务,使其接近汇编语言速度。
用pip install numba
安装Numba。
from numba import jit
import numpy as np
# Create a numpy array of length 10000 with float values between 0 and 10
random_values = np.random.uniform(0.0,10.0,size=(100*100,))
@jit(nopython=True, nogil=True)
def find_nearest(input):
result = []
for e in input:
counter = 0
for j in input:
if j >= (e-0.5) and j < e:
counter += 1
result.append(counter)
return result
result = find_nearest(random_values)
请注意,为测试用例返回了预期结果:
test = [0.1, 0.7, 0.8, 0.85, 0.9, 1.5, 1.7]
result = find_nearest(test)
print result
返回:
[0, 0, 1, 2, 3, 0, 1]
这将解决您的具体任务。
def find_nearest_element(original_array):
result_array = []
for e in original_array:
result_array.append(len(original_array[(e-0.5 < original_array) & (e > original_array)]))
return result_array
original_array = np.array([0.1, 0.7, 0.8, 0.85, 0.9, 1.5, 1.7])
print(find_nearest_element(original_array))
输出:
[0, 0, 1, 2, 3, 0, 1]
编辑:使用掩码比使用numba的较小阵列(ca.len 10000)的版本快得多。对于更大的阵列,使用Numba的版本更快。所以这取决于你想要进步的数组大小。
一些运行时比较(以秒为单位):
For smaller arrays(size=250):
Using Numba 0.2569999694824219
Using mask 0.0350041389465332
For bigger arrays(size=40000):
Using Numba 1.4619991779327393
Using mask 4.280000686645508
我的设备上的收支平衡点大约是10000(大约需要0.33秒)。
对于有序数组,这个问题很容易解决。您必须向后搜索并计算所有大于实际数字半径的数字。如果不再满足该条件,则可以突破内循环(这节省了大量时间)。
例
import numpy as np
from scipy import spatial
import numba as nb
@nb.njit(parallel=True)
def get_counts_2(Points_sorted,ind,r):
counts=np.zeros(Points_sorted.shape[0],dtype=np.int64)
for i in nb.prange(0,Points_sorted.shape[0]):
count=0
for j in range(i-1,0,-1):
if (Points_sorted[i]-r<Points_sorted[j]):
count+=1
else:
break
counts[ind[i]]=count
return counts
计时
r=0.001
Points=np.random.rand(1_000_000)
t1=time.time()
ind=np.argsort(Points)
Points_sorted=Points[ind]
counts=get_counts_2(Points_sorted,ind,r)
print(time.time()-t1)
#0.29s