推力中的argsort

问题描述 投票:0回答:1

在以下代码中使用

thrust::sort_by_key
是否合法?

#include <thrust/device_vector.h>
#include <thrust/sequence.h>
#include <thrust/iterator/permutation_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/sort.h>
#include <thrust/advance.h>
#include <thrust/copy.h>
    
#include <iterator>
#include <iostream>
    
int main()
{
    int init[] = {2, 0, 1, 3, 4};
    const thrust::device_vector< int > v{std::cbegin(init), std::cend(init)};
        
    thrust::device_vector< std::intptr_t > index{v.size()};
    thrust::sequence(index.begin(), index.end());
        
    auto key = 
        thrust::make_permutation_iterator(
            thrust::make_transform_iterator(
                v.cbegin(), 
                thrust::identity< thrust::tuple< int > >{}),
            index.cbegin());
    thrust::sort_by_key(
        key,
        thrust::next(key, index.size()),
        index.begin());
        
    thrust::copy(
        index.cbegin(), index.cend(),
        std::ostream_iterator< std::intptr_t >(std::cout, ", "));
    std::cout << std::endl;
}

这里

index
数组指向
v
值数组。我想在上面的排序之后有一个
v.index
的“排序视图”是视图,即
[v[i] for i in index]
(pythonic伪代码)被排序。

identity
转换的技巧在这里至关重要:它将
index
指向的值在
v
转换为单元素元组。
thrust::tuple
是一个类并且具有
operator =
,它不是仅针对左值的 cv-ref 限定的,因此可以用于由于取消引用
transform_iterator
而返回的右值。
thrust::tuple< int >(1) = 2;
是一个合法的声明,实际上是一个空操作,因为左边的值在赋值后立即下降。因此,
sort_by_key
中的键交换都是空操作,真正的排序发生在键值排序的“值”部分。同样不是,
v
在这里是不可变的(结果
v.cbegin()
是 const 迭代器)。

据我所知,Thrust 的开发人员通常假设所有可调用对象都是 idempotent。我相信这里没有违反假设,因为只有可调用 (

thrust::identity
) 的参数发生了变化,而不是可调用的状态。但另一方面,Thrust 花式迭代器的任何叠加都可以被视为函数的组合(比如,
permutation_iterator
是一个简单的映射)。

sort_by_key
index
是读写的。它可以被隐含的实施规则禁止。是正确的代码吗?

c++ algorithm sorting cuda thrust
1个回答
0
投票

虽然我无法回答这个问题,如果你的实现依赖于未定义的行为或算法细节来保证

index
上不存在竞争条件,有一种方法可以在 Thrust 中使用
thrust::sort
而不是
thrust::sort_by_key
实现“argsort”这更容易阅读、理解和争论:

#include <thrust/device_vector.h>
#include <thrust/sort.h>
#include <thrust/advance.h>
#include <thrust/copy.h>
#include <thrust/iterator/counting_iterator.h>

#include <iterator>
#include <iostream>

int main()
{
    int init[] = {2, 0, 1, 3, 4};
    const thrust::device_vector< int > v{std::cbegin(init), std::cend(init)};

    // optimization to avoid unnecessary initialization of index to zero
    auto const seq_iter =
        thrust::make_counting_iterator(
            static_cast< std::intptr_t >(0));

    thrust::device_vector< std::intptr_t > index{seq_iter,
                                                 thrust::next(seq_iter, v.size())};
    
    auto const v_ptr = v.data();

    thrust::sort(
        index.begin(), index.end(),
        [v_ptr] __host__ __device__ (std::intptr_t left_idx, std::intptr_t right_idx)
        {
            return v_ptr[left_idx] < v_ptr[right_idx];
        });

    thrust::copy(
        index.cbegin(), index.cend(),
        std::ostream_iterator< std::intptr_t >(std::cout, ", "));
    std::cout << std::endl;
}

由于设备 lambda,nvcc 需要

-extended-lambda
标志来编译它。可以使用命名仿函数而不是 lambda 来实现
sort
的比较器,以避免需要此标志。

© www.soinside.com 2019 - 2024. All rights reserved.