使用scipy.spatial.cKDTree或使用两个数据集之间的最近邻点具有随机2d点的queryBallPoint的问题

问题描述 投票:0回答:2

我有P_0点随机分布在2d框中。然后,我将它们分为两个组S和I。如果S的某些点离I太近,则将它们从S组中删除并添加到I组中。我面临的问题是,有时它们没有正确地从S中删除,但已正确地添加到I中。因此,总点数会错误地增加。这是代码:

from scipy.spatial import cKDTree
import numpy as np
import matplotlib.pyplot as plt
P_0 = 100 # initial susceptible population
# dimensions of box
Lx = 5.0
Ly = 5.0
# generate P_0 random points inside box
X = np.random.uniform(0, Lx, P_0)
Y = np.random.uniform(0, Ly, P_0)
pts = np.column_stack((X, Y)) # array of 2d points
S = np.arange(10, P_0) # indices of the susceptible
I = np.arange(10) # indices of the infected
# Divide points into infected and susceptible groups
r_I = pts[I]
r_S = pts[S]
tree = cKDTree(r_S)
# idx represents the indices to points in r_S which are closer than r to
# points in r_I
idx = tree.query_ball_point(r_I, r=0.4) 
idx = np.hstack(idx) # flatten the lists into one numpy array
idx = idx.astype(int) # Make sure idx indices have int type
print idx

# plot points
plt.figure()
plt.plot (r_S[:, 0], r_S[:, 1], 'bo') # plot all r_S points
plt.plot (r_S[idx, 0], r_S[idx, 1], 'ko') # color those points nearest to r_I
plt.plot (r_I[:, 0], r_I[:, 1], 'ro') # identify the r_I points
print len(S), len(I), len(S)+len(I)
I= np.append(I, S[idx]) # add the closest points to I
S = np.delete(S, idx) # delete the closest points from S

# points in r_I
idx = tree.query_ball_point(r_I, r=0.4) 
idx = np.hstack(idx) # flatten the lists into one numpy array
idx = idx.astype(int) # Make sure idx indices have int type
print idx

# plot points
plt.figure()
plt.plot (r_S[:, 0], r_S[:, 1], 'bo') # plot all r_S points
plt.plot (r_S[idx, 0], r_S[idx, 1], 'ko') # color those points nearest to r_I
plt.plot (r_I[:, 0], r_I[:, 1], 'ro') # identify the r_I points
print len(S), len(I), len(S)+len(I)
I= np.append(I, S[idx]) # add the closest points to I
S = np.delete(S, idx) # delete the closest points from S

plt.figure('S group')
plt.plot (pts[S, 0], pts[S, 1], 'bo') # plot the updated r_S points

plt.figure('I group')
plt.plot (pts[I, 0], pts[I, 1], 'ro') # plot the updated r_I points
print len(S), len(I), len(S)+len(I), len(idx)
plt.show()

所以,我不知道为什么r_S中的所有点都不都比r更近,有时不从S中删除。一个人可能必须运行几次代码才能出现错误,或者例如将P_0增加到1000或增加r的值。 idx以及我使用numpy delete的方式可能有问题。

python numpy scipy
2个回答
0
投票

您可以通过交换操作(仅在成功删除后先执行删除操作,然后再进行添加操作来进行移动,然后在单独的变量中测试删除操作,然后再进行测试,来仔细检查您的假设。

将删除后的结果大小与组的原始大小进行比较。如果大小匹配,则不会发生删除(无论出于何种原因),这表示不希望在另一侧添加任何内容。

然后,您可以在大写字母上打印组,并查看参与其中的索引以使隧道更加清晰。


0
投票

正如我刚才评论的,我只需要消除idx中的重复项。我加了行

idx = np.unique(idx)

在idx之下= idx.astype(int)

© www.soinside.com 2019 - 2024. All rights reserved.