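I am trying to implement the most efficient way to sort only k elements of an array while excluding some elements from the sort. Currently, I am using the following function, which relies on NumPy's argpartition: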
import numpy as np
# Assumed source of the distance helpers: scikit-learn's pairwise metrics.
from sklearn.metrics.pairwise import cosine_distances, euclidean_distances

def find_cads_indices_closest_to_target_distance(
    cad_idx,
    embeddings,
    target_distance,
    closest_to_target_distance_count=1,
    excluded_indices=None,
    metric="cosine",
):
    # Include cad_idx in excluded_indices by default
    # (build a new list so the caller's list is not mutated)
    if excluded_indices is None:
        excluded_indices = [cad_idx]
    elif cad_idx not in excluded_indices:
        excluded_indices = [*excluded_indices, cad_idx]
    reference_embedding = embeddings[cad_idx]
    # Compute distances based on the specified metric
    if metric == "cosine":
        distances = cosine_distances([reference_embedding], embeddings)[0]
    elif metric == "euclidean":
        distances = euclidean_distances([reference_embedding], embeddings)[0]
    else:
        raise ValueError("Invalid metric. Choose 'cosine' or 'euclidean'.")
    differences = np.abs(distances - target_distance)
    # Find indices of the embeddings with the smallest differences,
    # taking enough candidates to survive the exclusion step below
    candidate_count = closest_to_target_distance_count + len(excluded_indices)
    # kth must be a valid position, so clamp it to the last index
    kth = min(candidate_count, len(differences) - 1)
    indices_of_closest_to_target = np.argpartition(differences, kth)[:candidate_count]
    # Exclude the indices in excluded_indices from the list
    indices_of_closest_to_target = [
        idx for idx in indices_of_closest_to_target if idx not in excluded_indices
    ]
    # Optional, sort indices by distance
    indices_of_closest_to_target = sorted(
        indices_of_closest_to_target, key=lambda idx: differences[idx]
    )
    # If there are more than closest_to_target_distance_count indices, trim
    # the list; trimming after the sort keeps the closest ones
    indices_of_closest_to_target = indices_of_closest_to_target[
        :closest_to_target_distance_count
    ]
    # Return the cads corresponding to the closest embeddings and the distances to them
    return (
        indices_of_closest_to_target,
        distances[indices_of_closest_to_target],
    )
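I am not satisfied with this because it sorts more elements than necessary. I tried to ignore the indices natively during the sort, but without success. If you can help, whether through some clever use of argpartition or another approach, I will gladly take it!
Thanks!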
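It depends. There may be a trade-off between memory use and processing time, and the balance may depend on the data size, the value of k, and the number of items to exclude.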
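Either way, you can get argpartition to do the sorting for you by passing range(k) as kth rather than just k. When kth is a sequence, every listed position ends up holding the element a full sort would put there, so the first k indexes come back in sorted order.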
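To make the difference between a scalar and a sequence kth concrete, here is a small sketch. The array `values` and the variable names are made up for illustration, and `np.arange(k)` is used in place of `range(k)` purely as a matter of taste.

```python
import numpy as np

values = np.array([50, 10, 40, 20, 30, 60])
k = 3

# Scalar kth: the k smallest values land in the first k slots,
# but their order within those slots is not guaranteed.
unordered = np.argpartition(values, k)[:k]

# Sequence kth: positions 0..k-1 each hold the index of the element
# that a full sort would place there, so the slice is already sorted.
ordered = np.argpartition(values, np.arange(k))[:k]

print(sorted(values[unordered].tolist()))  # [10, 20, 30]
print(values[ordered].tolist())            # [10, 20, 30]
```

With the scalar form you would still need a separate sort of the k results; with the sequence form that step is folded into the partition.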
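If k is relatively small and there are only a few items to exclude, you may prefer to avoid copying a large data array, and sorting a few extra items may not matter. In that case, you could try something like this: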
def best_k_simple(data, k=1, exclude=()):
    # Assumes exclude has no duplicates and k + len(exclude) <= len(data).
    k_plus_excludes = k + len(exclude)
    # Get best k_plus_excludes indexes, in sorted order.
    # Slicing keeps the Python-level filtering below from scanning the whole array.
    best_indexes_sorted = np.argpartition(data, range(k_plus_excludes))[:k_plus_excludes]
    # Drop excluded indexes if in best.
    best_indexes_sorted = [i for i in best_indexes_sorted if i not in exclude]
    # Trim to just k best indexes.
    best_indexes_sorted = best_indexes_sorted[:k]
    return (best_indexes_sorted, data[best_indexes_sorted])
Demo:
import numpy as np
# Define toy data.
data = np.arange(8, 0, -1)
# array([8, 7, 6, 5, 4, 3, 2, 1])
# Get best 3, ignoring items 2 and 6 (data values 6 and 2).
best_k_simple(data, k=3, exclude=[2, 6])
# ([7, 5, 4], array([1, 3, 4]))
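If you really want to avoid sorting the items to be excluded, you will probably have to make a copy of the data that leaves those items out, like this: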
def best_k_limited_sorting(data, k=1, exclude=()):
    # Assumes k is no larger than the number of items left after exclusion.
    data_indexes = range(len(data))
    # Indexes of the items that take part in the sort.
    included_indexes = np.array(list(set(data_indexes) - set(exclude)))
    # Copy of the data without the excluded items.
    data_included = data[included_indexes]
    # Get best k indexes in data_included, in sorted order.
    best_indexes = np.argpartition(data_included, range(k))[:k]
    # Map to indexes in original data.
    best_data_indexes = included_indexes[best_indexes]
    return (best_data_indexes, data[best_data_indexes])
Demo:
# Get best 3, ignoring items 2 and 6 (data values 6 and 2).
best_k_limited_sorting(data, k=3, exclude=[2, 6])
# (array([7, 5, 4]), array([1, 3, 4]))
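A third variant, not part of the answer above and offered only as a sketch: copy the data as floats and overwrite the excluded slots with +inf, so that argpartition pushes them to the end on its own. The name `best_k_masked` is hypothetical. The approach assumes the data can be represented as floats and contains no inf or NaN values of its own. It still copies the array, but it skips the set construction and the index mapping.

```python
import numpy as np

def best_k_masked(data, k=1, exclude=()):
    """Hypothetical helper: k smallest items, skipping excluded indexes."""
    data = np.asarray(data)
    exclude = np.asarray(list(exclude), dtype=int)
    # Float copy, so excluded slots can be set to +inf without touching data.
    masked = data.astype(float)
    masked[exclude] = np.inf
    # Never request more items than remain after exclusion.
    k = min(k, masked.size - np.unique(exclude).size)
    if k <= 0:
        return np.array([], dtype=int), data[:0]
    # Sequence kth: the first k indexes come back in sorted order.
    best = np.argpartition(masked, np.arange(k))[:k]
    return best, data[best]

data = np.arange(8, 0, -1)
idx, vals = best_k_masked(data, k=3, exclude=[2, 6])
print(idx, vals)  # [7 5 4] [1 3 4]
```

In the question's function, `differences` is already a float array, so the same masking could be applied to a copy of it before calling argpartition with the plain count.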