我想在多维空间(最多 7 维)中找到给定点周围的直接邻居。
有关空间的重要事实:所有点都位于规则网格上,每个维度内部是等间距的,但不同维度之间的网格间距不相同。
(下面是生成这种各维度间距不均匀的网格的示例代码)
# x and y share the same fine grid (5 values, step 0.15); each gets its own array.
x_values, y_values = (np.linspace(start=-0.3, stop=0.3, num=5) for _ in range(2))
# z is much coarser (step 1.0), i.e. the spacing is unequal between dimensions.
z_values = np.linspace(start=1, stop=6, num=6)
MWE:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.neighbors import KDTree
import numpy as np
# Define ranges for X, Y, and Z values
x_values = np.linspace(-0.3, 0.3, 5)
y_values = np.linspace(-0.3, 0.3, 5)
z_values = np.linspace(1, 6, 6)  # unequal spacing: step 1.0 in z vs 0.15 in x and y
# z_values = np.linspace(-0.3, 0.3, 5)  # equal spacing case
# Create meshgrid to generate combinations of X, Y, and Z values
X, Y, Z = np.meshgrid(x_values, y_values, z_values)
# Flatten the meshgrid arrays into an (n_points, 3) array of all combinations
points = np.column_stack((X.ravel(), Y.ravel(), Z.ravel()))
# Create a KDTree object with the sample points, using plain Euclidean
# distance in the original (unscaled) units
kdtree = KDTree(points, leaf_size=30, metric='euclidean')
# Query point for which nearest neighbors will be found
# query_point = np.array([[0, 0, 0]])  # test query point for equal spacing
query_point = np.array([[0, 0, 2]])  # test query point for unequal spacing
# Ask for the 27 nearest points (a 3 x 3 x 3 neighborhood, query point included).
# With equal spacing this is exactly the grid neighborhood. With the coarse z
# grid it is not: every point of the z=2 plane (at most ~0.42 away) is closer
# than the next z layer (1.0 away), so the result is the whole z=2 plane plus
# (0, 0, 1) and (0, 0, 3). This is the problem being demonstrated.
distances, indices = kdtree.query(query_point, k=27)
# Plot all points in 3D
fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(points[:, 0], points[:, 1], points[:, 2], color='blue', label='All Points')
# Plot the query point in 3D
ax.scatter(query_point[:, 0], query_point[:, 1], query_point[:, 2], color='red', label='Query Point')
# Plot the nearest neighbors in 3D
nearest_neighbors = points[indices[0]]  # Get nearest neighbors using indices
ax.scatter(nearest_neighbors[:, 0], nearest_neighbors[:, 1], nearest_neighbors[:, 2], color='green', label='Nearest Neighbors')
# Connect the query point with its nearest neighbors in 3D
for neighbor in nearest_neighbors:
    ax.plot(
        [query_point[0, 0], neighbor[0]],
        [query_point[0, 1], neighbor[1]],
        [query_point[0, 2], neighbor[2]],
        color='gray',
        linestyle='--',
    )
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_title('KD-Tree Nearest Neighbors in 3D')
ax.legend()
plt.show()
print()
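上述代码的结果:由于 z 方向的间距远大于 x、y 方向,按欧氏距离计算时,z=2 平面上的全部 25 个点都比 z 方向上的相邻点更近。因此返回的 27 个“最近邻”是整个 z=2 平面加上 (0, 0, 1) 和 (0, 0, 3),而不是想要的 3×3×3 网格邻域。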
所需结果:在每个维度上都选取直接相邻的网格点,而不考虑它们的实际距离。对于 3D 内部点,这就是以查询点为中心的 3×3×3 = 27 个点(含查询点本身)。
解决思路:先按各维度的尺度对坐标重新缩放(例如除以标准差或网格步长),使各维度的间距可比,再在缩放后的空间中构建并查询 KD 树。然后必须“返回”到原始空间,以获得原始距离或用于绘图:可以做逆变换,或者直接用返回的索引去取原始数据。
这是代码:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.neighbors import KDTree
import numpy as np
# Define ranges for X, Y, and Z values
x_values = np.linspace(-0.3, 0.3, 5)
y_values = np.linspace(-0.3, 0.3, 5)
z_values = np.linspace(1, 6, 6)  # unequal spacing: step 1.0 in z vs 0.15 in x and y
# z_values = np.linspace(-0.3, 0.3, 5)  # equal spacing case
# Create meshgrid to generate combinations of X, Y, and Z values
X, Y, Z = np.meshgrid(x_values, y_values, z_values)
# Flatten the meshgrid arrays into an (n_points, n_dims) array of all combinations
points = np.column_stack((X.ravel(), Y.ravel(), Z.ravel()))


def _grid_step(column):
    """Return the smallest spacing between distinct values of one coordinate column."""
    unique_values = np.unique(column)
    # A degenerate axis (a single value) has no spacing; any positive scale works there.
    return np.diff(unique_values).min() if unique_values.size > 1 else 1.0


# Rescale every dimension by its own grid step so the lattice becomes
# unit-spaced along every axis. Dividing by the standard deviation does NOT do
# this in general: the std of an evenly spaced axis grows with the number of
# grid values, so axes with different point counts end up with different
# rescaled spacings. Only the columns of `points` are used, so this works
# unchanged for any number of dimensions (e.g. up to 7).
grid_steps = np.array([_grid_step(column) for column in points.T])
points_rescaled = points / grid_steps
# Chebyshev (max-coordinate) distance: in lattice units a point is an immediate
# neighbor iff it is at most one step away along every axis, whatever the real
# distances are.
kdtree = KDTree(points_rescaled, leaf_size=30, metric='chebyshev')
# Query point for which nearest neighbors will be found
# query_point = np.array([[0, 0, 0]])  # test query point for equal spacing
query_point = np.array([[0, 0, 2]])  # test query point for unequal spacing
# Use a radius query instead of a fixed k. The neighborhood holds 3**n_dims
# points in the interior (27 in 3-D, 2187 in 7-D, query point included) but
# fewer at the grid boundary, where a fixed k=27 would pull in non-neighbors.
# r=1.5 lies half-way between one step (neighbor) and two steps (not a
# neighbor), so round-off from the rescaling cannot change the result.
# Distances are in rescaled units; the indices map back to the original data.
indices, distances = kdtree.query_radius(
    query_point / grid_steps, r=1.5, return_distance=True, sort_results=True
)
# Plot all points in 3D
fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(points[:, 0], points[:, 1], points[:, 2], color='blue', label='All Points')
# Plot the query point in 3D
ax.scatter(query_point[:, 0], query_point[:, 1], query_point[:, 2], color='red', label='Query Point')
# Plot the nearest neighbors in 3D, taken from the ORIGINAL (unscaled) points
nearest_neighbors = points[indices[0]]  # Get nearest neighbors using indices
ax.scatter(nearest_neighbors[:, 0], nearest_neighbors[:, 1], nearest_neighbors[:, 2], color='green', label='Nearest Neighbors')
# Connect the query point with its nearest neighbors in 3D
for neighbor in nearest_neighbors:
    ax.plot(
        [query_point[0, 0], neighbor[0]],
        [query_point[0, 1], neighbor[1]],
        [query_point[0, 2], neighbor[2]],
        color='gray',
        linestyle='--',
    )
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_title('KD-Tree Nearest Neighbors in 3D')
ax.legend()
plt.show()
print()
原空间的结果: