在以下代码中使用
thrust::sort_by_key
是否合法?
#include <thrust/device_vector.h>
#include <thrust/sequence.h>
#include <thrust/iterator/permutation_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/sort.h>
#include <thrust/advance.h>
#include <thrust/copy.h>
#include <iterator>
#include <iostream>
int main()
{
int init[] = {2, 0, 1, 3, 4};
const thrust::device_vector< int > v{std::cbegin(init), std::cend(init)};
thrust::device_vector< std::intptr_t > index{v.size()};
thrust::sequence(index.begin(), index.end());
auto key =
thrust::make_permutation_iterator(
thrust::make_transform_iterator(
v.cbegin(),
thrust::identity< thrust::tuple< int > >{}),
index.cbegin());
thrust::sort_by_key(
key,
thrust::next(key, index.size()),
index.begin());
thrust::copy(
index.cbegin(), index.cend(),
std::ostream_iterator< std::intptr_t >(std::cout, ", "));
std::cout << std::endl;
}
这里
index
数组指向v
值数组。我想在上面的排序之后有一个v.index
的“排序视图”是视图,即[v[i] for i in index]
(pythonic伪代码)被排序。
identity
转换的技巧在这里至关重要:它将index
指向的值在v
转换为单元素元组。 thrust::tuple
是一个类并且具有 operator =
,它不是仅针对左值的 cv-ref 限定的,因此可以用于由于取消引用 transform_iterator
而返回的右值。 thrust::tuple< int >(1) = 2;
是一个合法的声明,实际上是一个空操作,因为左边的值在赋值后立即下降。因此,sort_by_key
中的键交换都是空操作,真正的排序发生在键值排序的“值”部分。同样不是,v
在这里是不可变的(结果 v.cbegin()
是 const 迭代器)。
据我所知,Thrust 的开发人员通常假设所有可调用对象都是 idempotent。我相信这里没有违反假设,因为只有可调用 (
thrust::identity
) 的参数发生了变化,而不是可调用的状态。但另一方面,Thrust 花式迭代器的任何叠加都可以被视为函数的组合(比如,permutation_iterator
是一个简单的映射)。
在
sort_by_key
index
是读写的。它可以被隐含的实施规则禁止。是正确的代码吗?
虽然我无法回答这个问题,如果你的实现依赖于未定义的行为或算法细节来保证
index
上不存在竞争条件,有一种方法可以在 Thrust 中使用 thrust::sort
而不是 thrust::sort_by_key
实现“argsort”这更容易阅读、理解和争论:
#include <thrust/device_vector.h>
#include <thrust/sort.h>
#include <thrust/advance.h>
#include <thrust/copy.h>
#include <thrust/iterator/counting_iterator.h>
#include <iterator>
#include <iostream>
int main()
{
int init[] = {2, 0, 1, 3, 4};
const thrust::device_vector< int > v{std::cbegin(init), std::cend(init)};
// optimization to avoid unnecessary initialization of index to zero
auto const seq_iter =
thrust::make_counting_iterator(
static_cast< std::intptr_t >(0));
thrust::device_vector< std::intptr_t > index{seq_iter,
thrust::next(seq_iter, v.size())};
auto const v_ptr = v.data();
thrust::sort(
index.begin(), index.end(),
[v_ptr] __host__ __device__ (std::intptr_t left_idx, std::intptr_t right_idx)
{
return v_ptr[left_idx] < v_ptr[right_idx];
});
thrust::copy(
index.cbegin(), index.cend(),
std::ostream_iterator< std::intptr_t >(std::cout, ", "));
std::cout << std::endl;
}
由于设备 lambda,nvcc 需要
-extended-lambda
标志来编译它。可以使用命名仿函数而不是 lambda 来实现 sort
的比较器,以避免需要此标志。