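Suppose k = 5. Given an unlabeled image, we want to find the 5 labeled images closest to it (by L2 distance) and look up their labels (specifically, their cifar10 class IDs). The most common of these labels is then used to classify the unlabeled image. For example, if the labels of the 5 nearest neighbors are 2, 0, 2, 2, 8, the predicted label for this unlabeled image is 2.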
If there is a tie among the k labels, the tied class ID with the lowest value should be used.
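The voting rule above (most common label wins, ties go to the lowest class ID) can be sketched with NumPy alone. This assumes the labels are non-negative integers, as cifar10 class IDs (0 to 9) are; `majority_vote` is a hypothetical helper name, not part of the assignment.

```python
import numpy as np

def majority_vote(neighbor_labels):
    """Return the most common label; ties go to the lowest class ID.

    Hypothetical helper. Assumes non-negative integer labels.
    np.bincount counts how often each label value occurs, and np.argmax
    returns the FIRST index holding the maximum count, which is the
    smallest label among those tied for most common.
    """
    counts = np.bincount(np.asarray(neighbor_labels))
    return int(np.argmax(counts))

print(majority_vote([2, 0, 2, 2, 8]))  # 2
print(majority_vote([1, 0, 1, 0]))     # 0 (0 and 1 are tied; 0 is lower)
```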
Here is my code:
```python
import numpy as np
from collections import Counter

def predict(dists, labels, k):
    # Indices of the k nearest neighbors for each unlabeled image,
    # sorted by ascending distance (closest first)
    nearest = np.argsort(dists, axis=1)[:, :k]
    # Labels of the nearest neighbors, shape (M, k)
    nearest_labels = labels[nearest]
    # Count the occurrences of each class ID among each row's neighbors
    counts = [Counter(row) for row in nearest_labels]
    # Pick the most common class ID for each row of nearest neighbors
    most_common = []
    for count in counts:
        # Break ties by choosing the label with the lowest value
        max_count = max(count.values())
        possible_labels = [label for label, c in count.items() if c == max_count]
        chosen_label = min(possible_labels)
        most_common.append(chosen_label)
    return np.array(most_common)
```
It fails with the following error:
```
Calling
predict(array([[0., 1., 1., 1.]]), array([0, 1, 1, 0]), 3)
produced an incorrect result.
Expected a shape-(1,) array with value:
array([0])
Got a shape-(1,) array with value:
array([1])
```
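Can someone explain this to me, or help me fix the problem in my code? I've been stuck on it for a while. I've tried implementing the tie-breaking in different ways and with different approaches.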
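One likely explanation (an assumption, since the grader's reference solution isn't shown) is a tie in *distance*, not in votes. In the failing call, training image 0 is at distance 0, while images 1, 2 and 3 are all tied at distance 1, so only two of those three fit into the k = 3 nearest. The sketch below, using only the inputs from the error message, shows that the predicted label depends on which tied neighbors are kept. The default `np.argsort` kind ("quicksort") is documented as not stable, so it does not guarantee any particular order among equal distances.

```python
import itertools
import numpy as np

dists = np.array([[0., 1., 1., 1.]])
labels = np.array([0, 1, 1, 0])
k = 3

# A stable sort keeps tied distances in their original index order,
# so the 3 nearest neighbors are indices 0, 1, 2 (labels 0, 1, 1).
nearest = np.argsort(dists, axis=1, kind="stable")[:, :k]
print(nearest, labels[nearest])  # [[0 1 2]] [[0 1 1]]

# Index 0 (distance 0) is always kept; the other two slots come from
# indices 1, 2, 3, which are all tied at distance 1. Try every choice
# and apply the lowest-ID majority vote to each.
results = {}
for pair in itertools.combinations([1, 2, 3], 2):
    chosen = [0, *pair]
    votes = np.bincount(labels[chosen])
    results[tuple(chosen)] = int(np.argmax(votes))
print(results)  # {(0, 1, 2): 1, (0, 1, 3): 0, (0, 2, 3): 0}
```

Under this reading, the voting and lowest-ID tie-breaking in the code above behave as the specification describes. The mismatch comes from which equally distant neighbors are selected, so it is worth checking whether the assignment defines a rule for distance ties and, if so, applying it explicitly instead of relying on the default sort order.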