最近邻搜索和VPTrees的性能问题

问题描述 投票:0回答:3

我读过这个Q / A - knn with big sparse matrices in python,我也有类似的问题。我有一个稀疏的雷达数据阵列 - 大小为125930,经度和纬度的形状相同。只有5%的数据不是NULL。其余都是NULL。

数据在球体上可用,因此我使用VPTree和大圆距来计算距离。网格间距是不规则的,我想将这些数据插值到球体上的规则网格中,其距离在纬度和经度方向上,网格间距为0.05度。粗网格中两个纬度之间的间距为0.01,两个经度之间的间距为0.09。所以我按照以下方式创建我的网格网格,并且我有以下数量的网格点 - 总共12960000,基于不规则网格的纬度和经度的最大值。

latGrid = np.arange(minLat,maxLat,0.05)
lonGrid = np.arange(minLo,maxLo,0.05)


gridLon,gridLat = np.meshgrid(lonGrid,latGrid)
grid_points = np.c_[gridLon.ravel(),gridLat.ravel()]

radar_data = radar_element[np.nonzero(radar_element)]
lat_surface = lat[np.nonzero(radar_element)]
lon_surface = lon[np.nonzero(radar_element)]

points = np.c_[lon_surface,lat_surface]
if points.size > 0:
   tree = vptree.VPTree(points,greatCircleDistance)
    for grid_point in (grid_points):
        indices = tree.get_all_in_range(grid_point,4.3)
        args.append(indices)

问题是查询

get_all_in_range

对于上述数据的每次传递,目前需要12分钟才能运行,我总共有175次传递,总时间为35小时,这是不可接受的。有没有办法减少网格点的数量(基于某些相似性)由于返回的大部分索引为空,因此发送到查询?我也使用过Scikit-learn的BallTree,性能甚至比这个更差。我不确定FLANN是否适用于我的问题。

python performance interpolation sparse-matrix nearest-neighbor
3个回答
1
投票

我只想转换为3D坐标并使用欧几里德距离。

你可以使用像Annoy这样的东西(披露:我是作者)

我建造的东西的例子:https://github.com/erikbern/ping/blob/master/plot.py


1
投票

我首先将您的雷达观测值作为纬度/经度进行空间索引。为了Python,让我们使用R-Tree。我会遵循这个概念:

http://toblerity.org/rtree/tutorial.html#using-rtree-as-a-cheapo-spatial-database

加载您的雷达观测:

for id, (y, x, m) in enumerate(observations):
    index.insert(id=id, bounds=(x, y, x, y), obj=(y,x,m))

然后,对于你想要的大圆距离,我会计算一个“安全”欧几里德距离来过滤出候选点。

您可以在R-Tree中查询输出网格点(x,y)附近的候选点:

candidates  =  idx.intersection((x - safe_distance, y - safe_distance, x + safe_distance, y+safe distance), objects=True)]

这将为您提供候选点列表[(y, x, m),...]

现在使用Great Circle计算过滤候选者。然后,您可以使用剩余的点对象进行插值。


1
投票

这是另一种在相反方向解决问题的策略。我认为这是一种更好的方法,原因有以下几点:

  • 雷达观测数据集是稀疏的,因此即使借助于空间索引,在每个输出点上运行计算似乎也是浪费的。
  • 输出网格具有规则间距,因此可以轻松计算。
  • 因此,查看每个观察点并计算附近的输出点并使用该信息来构建输出点列表以及它们接近的观察点将不那么重要。

观测数据的形式为(X,Y,M)(经度,纬度,测量值)。

输出是一个有规则间距的网格,就像每个.1度一样。

首先为接近观察的输出网格点创建一个字典:

output = {}

然后取一个观察点,找到大圆距离内附近的点。开始检查附近的输出点并按行/列向外迭代,直到找到GCD中所有可能的观察点。

这将为您提供X和Y的GCD内的网格点列表。类似于:

get_points(X,Y) ----> [[x1, y1], [x2,y2]....]

现在我们将翻转它。我们想要存储每个输出点和它附近的观察点列表。要将点存储在输出字典中,我们需要某种唯一键。 geohash(交错纬度和经度,并产生一个独特的字符串)是完美的。

对于每个输出点(xn,yn),计算geohash,并使用(xn,yn)向输出字典添加条目,并开始(或追加)观察列表:

key = Geohash.encode(y,x)
if key not in output:
    output[key] = { 'coords': [x, y], 'observations' = [[X,Y,M]] }
else:
    output[key][observations].append([X,Y,M])

我们存储原始的x,y而不是反转geohash并失去准确性。

当您运行所有观察结果后,您将获得需要计算的所有输出点的字典,每个点都有一个符合GCD要求的观察列表。

然后,您可以遍历这些点并计算输出数组索引和插值,并将其写入输出数组:

def get_indices(x,y):
    ''' converts a x,y to row,column in the output array '''
    ....
    return row, column

def get_weighted value(point):
    ''' takes a point entry and performs inverse distance weighting
        using the point's coordinates and list of observations '''
    ....
    return value

for point in output:
    row, column = get_indices(point['coords'])
    idw = get_weighted_value(point)
    outarray[column,row] = idw
© www.soinside.com 2019 - 2024. All rights reserved.