我有绘制在 XY 轴上的点的数据集,并且在每个数据集中,我提前知道这些点应形成的线数。我的目标是使用 RANSAC 等方法来检测这些线。在提供的示例中,我知道应该有 2 行。
这是此类数据集的直观表示:
我很难配置 RANSAC(或其他方法)来一致地检测正确的线数,特别是当线彼此靠近或具有不同的点密度时。
有没有办法指导 RANSAC(或建议另一种方法)来检测我预先知道的每个数据集的确切行数?
我尝试过的:
我的期望示例:
让我们创建一些数据集:
import numpy as np
from sklearn.cluster import HDBSCAN
from sklearn.linear_model import RANSACRegressor
np.random.seed(12345)
def dataset(n=3, m=10, slope=(2, 3), intercept=(4, 7), sigma=0.1):
x = np.linspace(0, 1, m)
a = np.linspace(*slope, n)
b = np.linspace(*intercept, n)
xs = []
ys = []
for i in range(n):
xs.extend(x)
ys.extend(a[i] * x + b[i] + sigma * np.random.normal(size=x.size))
X = np.stack([np.array(xs), np.array(ys)]).T
return X
n = 3
X = dataset(n=n)
第一次操作,我们需要对点进行聚类:
cluster = HDBSCAN(min_cluster_size=5, min_samples=3, cluster_selection_epsilon=0.15)
cluster.fit(X)
# array([ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 2, 2, 2, 2, 2, 2, 2, -1])
聚类并不完美,可能会根据您的数据集进行调整和/或在后处理中进行细化。
聚类完成后,我们可以对每个聚类执行RANSAC:
errors = []
fig, axe = plt.subplots()
for i in range(n):
q = cluster.labels_ == i
x = X[q, 0].reshape(-1, 1)
y = X[q, 1]
regressor = RANSACRegressor()
regressor.fit(x, y)
errors.append(X[:, 1] - regressor.predict(X[:,0].reshape(-1, 1)))
axe.scatter(x, y, marker="o")
axe.plot(x, regressor.predict(x))
axe.scatter(X[:, 0], X[:, 1], marker="x", color="black")
axe.grid()
E = np.array(errors).T
然后如果需要细化聚类,可以使用误差E(点到每条线的垂直距离)对聚类进行后处理,并对校正后的聚类再次执行RANSAC。
cluster2 = HDBSCAN(min_cluster_size=5, min_samples=3, cluster_selection_epsilon=0.15)
cluster2.fit(E)
# array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])