Efficiently sorting the k lowest elements of a numpy array while excluding some indices, using argpartition

Question (votes: 0, answers: 1)

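I am trying to implement the most efficient way to sort only k elements of an array while excluding some elements from the sort. Currently I am using the following function, which relies on numpy's argpartition: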

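# Imports assumed (not shown in the original post); the helper names match scikit-learn's.
import numpy as np
from sklearn.metrics.pairwise import cosine_distances, euclidean_distances

def find_cads_indices_closest_to_target_distance(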
    cad_idx,
    embeddings,
    target_distance,
    closest_to_target_distance_count=1,
    excluded_indices=None,
    metric="cosine",
):
    # Always exclude cad_idx itself; build a new set so the caller's list is not mutated
    excluded_indices = set(excluded_indices or ()) | {cad_idx}

    reference_embedding = embeddings[cad_idx]
    # Compute distances based on the specified metric
    if metric == "cosine":
        distances = cosine_distances([reference_embedding], embeddings)[0]
    elif metric == "euclidean":
        distances = euclidean_distances([reference_embedding], embeddings)[0]
    else:
        raise ValueError("Invalid metric. Choose 'cosine' or 'euclidean'.")

    differences = np.abs(distances - target_distance)

    # Find indices of the embeddings with the smallest differences
    indices_of_closest_to_target = np.argpartition(
        differences, closest_to_target_distance_count + len(excluded_indices)
    )[: closest_to_target_distance_count + len(excluded_indices)]

    # Exclude the indices in excluded_indices from the list
    indices_of_closest_to_target = [
        idx for idx in indices_of_closest_to_target if idx not in excluded_indices
    ]

    # If more than closest_to_target_distance_count indices remain, trim the list
    indices_of_closest_to_target = indices_of_closest_to_target[
        :closest_to_target_distance_count
    ]

    # Optional, sort indices by distance
    indices_of_closest_to_target = sorted(
        indices_of_closest_to_target, key=lambda idx: differences[idx]
    )

    # Return the cads corresponding to the closest embeddings and the distances to them
    return (
        indices_of_closest_to_target,
        distances[indices_of_closest_to_target],
    )

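I'm not happy with this because it sorts more elements than needed. I tried to ignore the indices natively during the sort, but without success. If you can help me, through some clever use of argpartition or otherwise, I'll take it!

Thanks!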

python arrays numpy sorting
1 Answer
0 votes

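It depends. There may be a trade-off between the memory used and the processing time, and that trade-off may depend on the size of the data, the value of k, and the number of items to exclude.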

Either way, you can get argpartition to do the sorting for you by passing range(k) rather than just k.
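To illustrate the difference between the two forms of the kth argument, here is a minimal sketch on toy data (the variable names are illustrative only). With an integer kth, the k smallest values end up in the first k slots in no particular order; with a sequence such as range(k), each listed position is placed in its final sorted position.

```python
import numpy as np

data = np.array([8, 7, 6, 5, 4, 3, 2, 1])
k = 3

# kth=k only guarantees that the k smallest values occupy the first k slots;
# their order within those slots is unspecified.
unordered = np.argpartition(data, k)[:k]

# kth=range(k) places positions 0..k-1 in their final sorted positions,
# so the first k indexes come back ordered by value.
ordered = np.argpartition(data, range(k))[:k]

print(sorted(data[unordered].tolist()))  # [1, 2, 3]
print(data[ordered].tolist())            # [1, 2, 3]
```

The first result has to be sorted before printing because its order is not guaranteed; the second does not.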

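If k is relatively small and there are only a few items to exclude, you may prefer to avoid copying a large data array, and sorting a few extra items probably won't matter. In that case, you could try something like this: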

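def best_k_simple(data, k=1, exclude=()):
    # A set gives fast membership tests and avoids a mutable default argument.
    excluded = set(exclude)
    # Clamp so that kth never runs past the end of the data.
    k_plus_excludes = min(k + len(excluded), len(data))
    # Get best k_plus_excludes indexes, in sorted order.
    best_indexes_sorted = np.argpartition(data, range(k_plus_excludes))[:k_plus_excludes]
    # Drop excluded indexes if in best (tolist() yields plain Python ints).
    best_indexes_sorted = [i for i in best_indexes_sorted.tolist() if i not in excluded]
    # Trim to just k best indexes.
    best_indexes_sorted = best_indexes_sorted[:k]
    return (best_indexes_sorted, data[best_indexes_sorted])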

Demo

import numpy as np

# Define toy data.
data = np.arange(8, 0, -1)
# array([8, 7, 6, 5, 4, 3, 2, 1])

# Get best 3, ignoring items 2 and 6 (data values 6 and 2).
best_k_simple(data, k=3, exclude=[2, 6])
# ([7, 5, 4], array([1, 3, 4]))

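If you really want to avoid sorting the items that are to be excluded, you will probably have to make a copy of the data with those items removed, like this: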

def best_k_limited_sorting(data, k=1, exclude=()):
    # Boolean mask of the items to keep (cheaper than building Python sets).
    keep = np.ones(len(data), dtype=bool)
    keep[np.asarray(list(exclude), dtype=np.intp)] = False
    included_indexes = np.flatnonzero(keep)
    data_included = data[included_indexes]
    # Never ask for more items than remain after exclusion.
    k = min(k, len(data_included))
    # Get best k indexes in data_included, in sorted order.
    best_indexes = np.argpartition(data_included, range(k))[:k]
    # Map to indexes in original data.
    best_data_indexes = included_indexes[best_indexes]
    return (best_data_indexes, data[best_data_indexes])

Demo

# Get best 3, ignoring items 2 and 6 (data values 6 and 2).
best_k_limited_sorting(data, k=3, exclude=[2, 6])
# (array([7, 5, 4]), array([1, 3, 4]))
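
A possible third variant, not part of the answer above: when the values are floats, as the differences array in the question is, the excluded entries can be overwritten with +inf so that they can never rank among the k smallest. The sketch below uses a hypothetical helper named best_k_masked and assumes all real values are finite and that k is at least 1. It copies the array to leave the caller's data untouched; in the question's function, differences is a fresh temporary, so the copy could probably be skipped.

```python
import numpy as np

def best_k_masked(differences, k=1, exclude=()):
    """Hypothetical helper: indexes of the k smallest values, skipping `exclude`."""
    differences = np.asarray(differences)
    excluded = sorted(set(exclude))
    # Work on a float copy so the caller's array is not modified.
    masked = np.array(differences, dtype=float)
    # +inf sorts after every finite value, so excluded items are never selected.
    masked[np.asarray(excluded, dtype=np.intp)] = np.inf
    # Cap k at the number of items that remain after exclusion.
    k = min(k, len(masked) - len(excluded))
    # range(k) returns the k best indexes already ordered by value.
    best = np.argpartition(masked, range(k))[:k]
    return best, differences[best]

differences = np.array([0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])
best, values = best_k_masked(differences, k=3, exclude=[2, 6])
print(best.tolist())    # [7, 5, 4]
print(values.tolist())  # [0.1, 0.3, 0.4]
```

This partitions an array of the original length rather than a shortened one, but it needs no index mapping back to the original data.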