在 Python 中从大型数据集中高效查找最近位置

我正在开发一个项目:对于一个包含 500,000 个地址的大型数据集中的每个条目,需要从另一个包含 1,000 个条目的数据集中找出距离最近的三个地点。两个数据集都包含纬度和经度坐标。我目前用 Python 配合 geopy 和 pandas 来实现,但由于数据量很大,我不确定最好的做法是什么。


import pandas as pd
from geopy.distance import geodesic

# Sample data setup: the table of candidate spots (zip code, city and
# decimal-degree coordinates) that find_nearest_spots searches.
data = {
    "Zip Code": ["10115", "20095", "50667", "80331", "70173", "40210", "41460", "45127", "47051", "40474"],
    "City": ["Berlin", "Hamburg", "Cologne", "Munich", "Stuttgart", "Düsseldorf", "Neuss", "Essen", "Duisburg", "Düsseldorf Airport"],
    "Latitude": [52.5323, 53.5503, 50.9367, 48.1372, 48.7833, 51.2217, 51.1981, 51.4556, 51.4344, 51.2895],
    "Longitude": [13.3846, 9.9930, 6.9540, 11.5755, 9.1815, 6.7763, 6.6913, 7.0116, 6.7623, 6.7668],
}  # closing brace was missing, which made the whole script a SyntaxError
df = pd.DataFrame(data)

# The address table to enrich with its nearest spots (a single sample row here).
data_df2 = {
    "Address": ["my_location"],
    "ZipCode": ["40468"],
    "Latitude": [51.28472232436951],
    "Longitude": [6.7865234073914005],
}  # closing brace was missing, which made the whole script a SyntaxError
df2 = pd.DataFrame(data_df2)

# Distance calculation
def calculate_distance(lat1, lon1, lat2, lon2):
    """Return the geodesic distance, in kilometres, between two points.

    Each point is given as a latitude/longitude pair and handed to
    ``geopy.distance.geodesic`` in ``(lat, lon)`` order.
    """
    origin = (lat1, lon1)
    target = (lat2, lon2)
    return geodesic(origin, target).kilometers

def find_nearest_spots(address_lat, address_lon):
    """Return the three spots in ``df`` closest to one address.

    Every row of the module-level ``df`` is scored with
    ``calculate_distance`` (geodesic kilometres) and the three smallest
    distances are kept, closest first.

    Args:
        address_lat: Latitude of the address, in the same decimal-degree
            form as ``df['Latitude']``.
        address_lon: Longitude of the address, in the same form as
            ``df['Longitude']``.

    Returns:
        pd.Series with keys ``Nearest_1_Spot``, ``Nearest_1_Dist``,
        ``Nearest_2_Spot``, ``Nearest_2_Dist``, ``Nearest_3_Spot`` and
        ``Nearest_3_Dist``. Each ``*_Spot`` is the ``City`` of the i-th
        closest row and each ``*_Dist`` is its distance in kilometres.
    """
    distances = df.apply(
        lambda row: calculate_distance(address_lat, address_lon,
                                       row['Latitude'], row['Longitude']),
        axis=1,
    )
    # nsmallest() keeps df's index *labels*. Read label and distance together
    # from its result instead of passing the labels to .iloc. .iloc is
    # positional and only agrees with the labels for a default 0..n-1 index;
    # after filtering, sorting or index_col=... it returns the wrong distances.
    nearest = distances.nsmallest(3)

    result = {}
    for rank, (label, distance) in enumerate(nearest.items(), start=1):
        result[f'Nearest_{rank}_Spot'] = df.at[label, 'City']
        result[f'Nearest_{rank}_Dist'] = distance
    return pd.Series(result)

# For every address row, look up its three nearest spots and store the
# spot names and distances as six new columns.
nearest_cols = [
    'Nearest_1_Spot', 'Nearest_1_Dist',
    'Nearest_2_Spot', 'Nearest_2_Dist',
    'Nearest_3_Spot', 'Nearest_3_Dist',
]
df2[nearest_cols] = df2.apply(
    lambda address: find_nearest_spots(address['Latitude'], address['Longitude']),
    axis=1,
)

如有必要,我也可以使用免费的 GPU 库(如果有合适的话)。


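正如评论中提到的,我建议使用 KDTree 来查找最近邻。注意:下面的示例直接在经纬度(单位为度)上计算欧氏距离,这只是平面近似;如果需要按真实地理距离排序,可以先把经纬度转换为单位球面上的三维笛卡尔坐标再建树——弦长与大圆距离单调对应,因此最近邻的顺序保持不变。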

下面是一个简单示例,演示如何在包含 500,000 个点的列表中,为 1,000 个查询点分别找出 3 个最近邻(对于问题中的场景,应改为用 1,000 个地点建树,再用 500,000 个地址进行查询——代码相同,只需交换两组坐标数组):

from scipy.spatial import KDTree
import numpy as np

# Set up a search tree over the locations.
# NOTE: the tree is built on raw (lon, lat) pairs, so every distance it
# reports, and distance_upper_bound below, is a Euclidean distance in
# degrees (a planar approximation), not a geodesic distance in kilometres.
lons, lats = np.random.normal(0, 50, 500000), np.random.normal(0, 50, 500000)
tree = KDTree(np.column_stack([lons, lats]))

# Query points: the k=3 nearest neighbours of each, ignoring anything
# farther away than distance_upper_bound (1 degree).
q_lons, q_lats = np.random.normal(0, 50, 1000), np.random.normal(0, 50, 1000)
dist, idx = tree.query(np.column_stack([q_lons, q_lats]), k=3, distance_upper_bound=1)

# KDTree.query reports a missing neighbour (none within the bound) with the
# out-of-range index tree.n; build a mask of those slots.
mask = idx == tree.n

# Extract neighbour coordinates.
# tree.n is one past the last valid index, so np.take would raise on it in
# the default mode. mode="wrap" maps it back into range (the original
# "warp" was a typo that NumPy only tolerates with a DeprecationWarning),
# and the wrapped placeholder values are then overwritten with NaN.
neighbour_lons = np.take(lons, idx, mode="wrap")
neighbour_lons[mask] = np.nan
neighbour_lats = np.take(lats, idx, mode="wrap")
neighbour_lats[mask] = np.nan


import matplotlib.pyplot as plt

# One tab20 colour per query point, keyed on its squared distance from the
# origin (normalised to the largest one), so that each query and its own
# neighbours share a colour.
c = plt.cm.tab20((q_lons**2 + q_lats**2) / (q_lons**2 + q_lats**2).max())

f, ax = plt.subplots()
# Background: every location in the search tree, as small black dots.
ax.scatter(lons, lats, s=5, c="k")
# Query points: large coloured circles.
ax.scatter(q_lons, q_lats, s=70, c=c, marker="o")
# Neighbours: crosses. neighbour_lons/lats are (n_queries, 3), so ravel()
# lists each query's 3 neighbours consecutively; repeating every colour row
# 3 times lines the colours up with them.
neighbour_colors = c.repeat(3, axis=0)
ax.scatter(neighbour_lons.ravel(), neighbour_lats.ravel(),
           s=70, c=neighbour_colors, marker="x")

[图:黑色小点为搜索树中的所有位置,彩色圆点为查询点,同色的 × 为各查询点的 3 个最近邻]

