How do I convert this Python quicksort implementation into an equivalent of NumPy's argsort?


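I have an iterative quicksort implementation in Python. I would like it to perform an argsort instead of a sort, so that the resulting array holds the rank of each item in sorted order rather than the items themselves.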

import numpy as np

def argSortPartition(arr, arg, l, h):
    i = (l - 1)
    x = arr[h]

    # plain range: prange is never imported here, and the loop carries state in i
    for j in range(l, h):
        if arr[j] <= x:
            # increment index of smaller element
            i = i + 1
            arr[i], arr[j] = arr[j], arr[i]
            arg[j] = arg[j] + 1
            arg[i] = arg[i] - 1

    arr[i + 1], arr[h] = arr[h], arr[i + 1]
    arg[i] = arg[i] + 1
    arg[j + 1] = arg[j + 1] + 1
    return (i + 1)


def quickArgSortIterative(arr, start_index, end_index):
    # Create an auxiliary stack
    size = end_index - start_index + 1
    stack = [0] * (size)
    arg = list(range(size))

    # initialize top of stack
    top = -1

    # push initial values of l and h to stack
    top = top + 1
    stack[top] = start_index
    top = top + 1
    stack[top] = end_index

    # Keep popping from stack while is not empty
    while top >= 0:
        # Pop h and l
        end_index = stack[top]
        top = top - 1
        start_index = stack[top]
        top = top - 1

        # Set pivot element at its correct position in
        # sorted array
        p = argSortPartition(arr, arg, start_index, end_index)

        # If there are elements on left side of pivot,
        # then push left side to stack
        if p - 1 > start_index:
            top = top + 1
            stack[top] = start_index
            top = top + 1
            stack[top] = p - 1

        # If there are elements on right side of pivot,
        # then push right side to stack
        if p + 1 < end_index:
            top = top + 1
            stack[top] = p + 1
            top = top + 1
            stack[top] = end_index

    return arg

d = np.array([50, 30, 20, 40, 60])
argsorted = quickArgSortIterative(d, 0, len(d) - 1)
print(argsorted)

>>> Should print [3, 1, 0, 2, 4], but I get [0, 3, 3, 5, 5] on this latest version.
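
For reference, two different outputs are in play here. An argsort lists, for each sorted position, the index of the original item; a ranking lists, for each original item, its sorted position. The two are inverse permutations of each other. A minimal sketch in plain Python follows. The helper names `argsort_reference` and `ranks_from_argsort` are hypothetical, and the built-in `sorted` is used only to show the expected values, not as something that would compile under Numba CUDA:

```python
def argsort_reference(values):
    # Indices of `values`, ordered by the value each index points to.
    return sorted(range(len(values)), key=lambda k: values[k])


def ranks_from_argsort(order):
    # Invert the permutation: the item at original index order[pos]
    # lands at sorted position pos.
    ranks = [0] * len(order)
    for pos in range(len(order)):
        ranks[order[pos]] = pos
    return ranks


d = [50, 30, 20, 40, 60]
order = argsort_reference(d)
ranks = ranks_from_argsort(order)
print(order)  # [2, 1, 3, 0, 4]
print(ranks)  # [3, 1, 0, 2, 4]
```

The expected output quoted above, [3, 1, 0, 2, 4], is therefore the ranking; `np.argsort` on the same data returns [2, 1, 3, 0, 4].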

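This is my attempt (the sorting works fine, but arg, which should hold the output of the argsort, does not). For efficiency I need the sort to be done only once, so it is an indirect sort operation.

I have tried many things, but at this point I seem to be going in circles. I know it is possible, but I cannot find a reference implementation in any other language.

This will eventually run in Numba CUDA compiled code, so I cannot simply use NumPy's argsort, nor can I use fancy Python language constructs such as list comprehensions. It has to be done the caveman way :)

Any help is much appreciated!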

python sorting quicksort np.argsort
1 Answer

I fixed it. Here is a working, performant argsort implementation that does not use recursion and can be compiled for Numba CUDA:

import numpy as np

def argSortPartition(arr, arg, l, h):
    i = l - 1
    x = arr[h]

    # plain range: prange is never imported here, and the loop carries state in i
    for j in range(l, h):
        if arr[j] <= x:
            # increment index of smaller element
            i = i + 1
            arr[i], arr[j] = arr[j], arr[i]
            arg[i], arg[j] = arg[j], arg[i]

    i += 1
    arr[i], arr[h] = arr[h], arr[i]
    arg[i], arg[h] = arg[h], arg[i]
    return i


def quickArgSortIterative(arr, start_index, end_index):
    # Create an auxiliary stack
    size = end_index - start_index + 1
    stack = [0] * size
    arg = list(range(size))

    # initialize top of stack
    top = -1

    # push initial values of l and h to stack
    top = top + 1
    stack[top] = start_index
    top = top + 1
    stack[top] = end_index

    # Keep popping from stack while is not empty
    while top >= 0:
        # Pop h and l
        end_index = stack[top]
        top = top - 1
        start_index = stack[top]
        top = top - 1

        # Set pivot element at its correct position in
        # sorted array
        p = argSortPartition(arr, arg, start_index, end_index)

        # If there are elements on left side of pivot,
        # then push left side to stack
        if p - 1 > start_index:
            top = top + 1
            stack[top] = start_index
            top = top + 1
            stack[top] = p - 1

        # If there are elements on right side of pivot,
        # then push right side to stack
        if p + 1 < end_index:
            top = top + 1
            stack[top] = p + 1
            top = top + 1
            stack[top] = end_index

    # must uncomment this loop if returning rankings instead of argsort
    # for i in range(len(arg)):
    #    stack[arg[i]] = i
    
    # to return rankings instead, return stack
    return arg

d = np.array([50, 30, 20, 40, 60])
argsorted = quickArgSortIterative(d, 0, len(d) - 1)
print(argsorted)
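
The code above keeps `arg` in step with `arr` by swapping both together, which also sorts the input array in place. As a variation (not the answer's exact method), the partition can compare through the index array and swap only the indices, leaving the data untouched. A minimal sketch, assuming plain Python lists and the hypothetical names `arg_partition` and `quick_argsort`; it has not been compiled with Numba CUDA, where the two work buffers would presumably need to be preallocated arrays rather than lists:

```python
def arg_partition(values, arg, lo, hi):
    # Lomuto partition over the index array only; `values` is never modified.
    pivot = values[arg[hi]]
    i = lo - 1
    for j in range(lo, hi):
        if values[arg[j]] <= pivot:
            i += 1
            arg[i], arg[j] = arg[j], arg[i]
    i += 1
    arg[i], arg[hi] = arg[hi], arg[i]
    return i


def quick_argsort(values):
    n = len(values)
    arg = [0] * n
    for k in range(n):
        arg[k] = k
    # Guard: the flat stack below needs two slots for the very first push.
    if n < 2:
        return arg

    # Flat stack of (lo, hi) pairs. Pending ranges are disjoint and hold at
    # least two items each, so n slots are enough.
    stack = [0] * n
    stack[0] = 0
    stack[1] = n - 1
    top = 1

    while top >= 0:
        hi = stack[top]
        lo = stack[top - 1]
        top -= 2

        p = arg_partition(values, arg, lo, hi)

        # Push the left part if it has at least two items.
        if p - 1 > lo:
            stack[top + 1] = lo
            stack[top + 2] = p - 1
            top += 2
        # Push the right part if it has at least two items.
        if p + 1 < hi:
            stack[top + 1] = p + 1
            stack[top + 2] = hi
            top += 2
    return arg


d = [50, 30, 20, 40, 60]
order = quick_argsort(d)

# Rankings are the inverse permutation of the argsort.
ranks = [0] * len(order)
for pos in range(len(order)):
    ranks[order[pos]] = pos

print(order)  # [2, 1, 3, 0, 4]
print(ranks)  # [3, 1, 0, 2, 4]
print(d)      # [50, 30, 20, 40, 60] (input left unchanged)
```

Like any quicksort, this is not stable, so equal values may come out in a different index order than NumPy's stable argsort kinds would produce.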

It is a little tricky; I hope it helps someone else.
