我在 CUDA 设备上有一个数据结构,其中收集的数据如下:
// Pair list from the question: two parallel int arrays (id1, id2) plus a count.
// NOTE(review): the name suggests the pointers refer to device memory — confirm
// against the code that allocates them.
struct SearchDataOnDevice
{
size_t npair; // presumably the number of (id1, id2) pairs stored — confirm
int * id1; // first id of each pair
int * id2; // second id of each pair
};
我想删除重复的 id 对,并通过一个名为“same_id_src”的选项来控制行为:当 same_id_src=true 时,<0, 5> 和 <5, 0> 视为重复,应删除 <5, 0>;当 same_id_src=false 时,这两对都应该保留。
我是 CUDA 和 Thrust 库的新手,有人可以给一些简单的提示吗?
这是一种可能的方法:
1) thrust::sort:对配对排序,使重复项相邻
2) thrust::transform:标记需要保留的配对
3) thrust::copy_if:根据标记进行流压缩
对排序后的配对进行流压缩即可得到去重结果。需要处理的特殊情况是:
<0 5> 和
<5 0> 是否被视为“相同”。这一点通过修改排序仿函数以及变换仿函数来处理。在排序仿函数中,我们仅为了比较而对每个配对重新排序,使较小的 ID 排在配对的前面。我们必须仔细设计排序仿函数,以便在特殊条件为真时,
<0 5> 优先于
<5 0> 被选中。
这是一个例子:
# cat t76.cu
#include <thrust/sort.h>
#include <thrust/copy.h>
#include <thrust/transform.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <iostream>
#include <cstdlib>
// Comparator for (id1, id2) tuples, used to sort pairs so that duplicates
// end up adjacent.
//
//   same_id_src == false : plain lexicographic order on (first, second).
//   same_id_src == true  : pairs are ordered by (smaller id, larger id), so
//                          <a,b> and <b,a> become neighbours; among those the
//                          pair already stored as <smaller, larger> sorts
//                          first, so it is the one kept by the later steps.
//
// The comparator is a strict weak ordering: comp(x, x) is always false.
struct my_sort_functor
{
  bool same_id_src;
  my_sort_functor(bool _same_id_src) : same_id_src(_same_id_src) {}

  // Returns true when pair t1 must be ordered strictly before pair t2.
  template <typename T1, typename T2>
  __host__ __device__
  bool operator()(T1 t1, T2 t2) const {
    const int t1a = thrust::get<0>(t1);
    const int t1b = thrust::get<1>(t1);
    const int t2a = thrust::get<0>(t2);
    const int t2b = thrust::get<1>(t2);

    if (same_id_src) {
      // "swapped" means the pair is stored as <larger, smaller>
      const bool t1s = (t1a > t1b);
      const bool t2s = (t2a > t2b);
      const int t1lo = t1s ? t1b : t1a;
      const int t1hi = t1s ? t1a : t1b;
      const int t2lo = t2s ? t2b : t2a;
      const int t2hi = t2s ? t2a : t2b;

      // sort on smaller id first, then on larger id
      if (t1lo != t2lo) return t1lo < t2lo;
      if (t1hi != t2hi) return t1hi < t2hi;

      // Same unordered pair: only <lo,hi> vs <hi,lo> is strictly ordered.
      // Returning true for two identically stored pairs would make
      // comp(x, x) true and violate the ordering contract of thrust::sort.
      return !t1s && t2s;
    }

    // no reordering of pairs: first id, then second id
    if (t1a != t2a) return t1a < t2a;
    return t1b < t2b;
  }
};
// Binary functor applied to two pairs (t1, t2): returns true when t2 is
// distinct from t1 and false when t2 duplicates t1.
// With same_id_src set, a pair whose ids are those of t1 in reverse order
// also counts as a duplicate.
struct my_transform_functor
{
  bool same_id_src;
  my_transform_functor(bool _same_id_src) : same_id_src(_same_id_src) {};

  template <typename T1, typename T2>
  __host__ __device__
  bool operator()(T1 t1, T2 t2){
    // exact match: <a,b> followed by <a,b>
    const bool identical =
        (thrust::get<0>(t1) == thrust::get<0>(t2)) &&
        (thrust::get<1>(t1) == thrust::get<1>(t2));
    // mirrored match: <a,b> followed by <b,a>, only relevant with same_id_src
    const bool mirrored =
        same_id_src &&
        (thrust::get<0>(t1) == thrust::get<1>(t2)) &&
        (thrust::get<1>(t1) == thrust::get<0>(t2));
    return !(identical || mirrored);
  }
};
// Identity predicate: returns the boolean flag it is given unchanged.
struct my_copy_predicate
{
  __host__ __device__
  bool operator()(bool keep) {
    return keep;
  }
};
// Demo driver: uploads a small list of id pairs, removes duplicates with
// sort -> transform -> copy_if, and prints the surviving pairs (first ids on
// one line, second ids on the next).
// Running with no arguments treats <a,b> and <b,a> as duplicates
// (same_id_src = true); passing any argument disables that.
int main(int argc, char *argv[]){
  // Report a failed CUDA runtime call and stop; later calls would only fail
  // in more confusing ways once an earlier one has gone wrong.
  auto check = [](cudaError_t err, const char *what) {
    if (err != cudaSuccess) {
      std::cerr << what << " failed: " << cudaGetErrorString(err) << std::endl;
      std::exit(EXIT_FAILURE);
    }
  };

  // data setup
  int d1[] = {0, 1, 2, 3, 4, 5, 0};
  int d2[] = {5, 2, 1, 2, 1, 0, 5};
  static_assert(sizeof(d1) == sizeof(d2), "id arrays must have equal length");
  bool same_id_src = true;
  if (argc > 1) same_id_src = false;
  size_t npair = sizeof(d1)/sizeof(d1[0]);

  int *id1 = nullptr, *id2 = nullptr;
  check(cudaMalloc(&id1, sizeof(d1)), "cudaMalloc(id1)");
  check(cudaMalloc(&id2, sizeof(d2)), "cudaMalloc(id2)");
  check(cudaMemcpy(id1, d1, sizeof(d1), cudaMemcpyHostToDevice), "cudaMemcpy(id1)");
  check(cudaMemcpy(id2, d2, sizeof(d2), cudaMemcpyHostToDevice), "cudaMemcpy(id2)");

  // wrap the raw device pointers so Thrust dispatches to the device backend
  auto dp_id1 = thrust::device_ptr<int>(id1);
  auto dp_id2 = thrust::device_ptr<int>(id2);
  auto dzip = thrust::make_zip_iterator(thrust::make_tuple(dp_id1, dp_id2));

  // stencil[i] == true means pair i is kept; element 0 is never overwritten
  // by the transform below, so the first sorted pair always survives
  thrust::device_vector<bool> stencil(npair, true);
  thrust::device_vector<int> r1(npair);
  thrust::device_vector<int> r2(npair);
  auto rzip = thrust::make_zip_iterator(thrust::make_tuple(r1.begin(), r2.begin()));

  // step 1: sort
  thrust::sort(dzip, dzip+npair, my_sort_functor(same_id_src));
  // step 2: mark pairs to be kept in stencil (compare each pair with its predecessor)
  if (npair > 1)  // avoid an unsigned wrap of npair-1 on empty input
    thrust::transform(dzip, dzip+npair-1, dzip+1, stencil.begin()+1, my_transform_functor(same_id_src));
  // step 3: copy if, using stencil
  int rsize = thrust::copy_if(dzip, dzip+npair, stencil.begin(), rzip, my_copy_predicate()) - rzip;

  // display result
  thrust::copy_n(r1.begin(), rsize, std::ostream_iterator<int>(std::cout, " "));
  std::cout << std::endl;
  thrust::copy_n(r2.begin(), rsize, std::ostream_iterator<int>(std::cout, " "));
  std::cout << std::endl;

  // release the raw device allocations made above
  check(cudaFree(id1), "cudaFree(id1)");
  check(cudaFree(id2), "cudaFree(id2)");
  return 0;
}
# nvcc -o t76 t76.cu
# compute-sanitizer ./t76
========= COMPUTE-SANITIZER
0 1 4 3
5 2 1 2
========= ERROR SUMMARY: 0 errors
# compute-sanitizer ./t76 1
========= COMPUTE-SANITIZER
0 1 2 3 4 5
5 2 1 2 1 0
========= ERROR SUMMARY: 0 errors
#
当我们不指定命令行参数时,特殊条件 same_id_src 被视为 true,多余的“重复项”(例如 <5, 0>)会被删除。当我们指定任意命令行参数时,特殊条件被视为 false,只删除完全相同的配对,<0, 5> 和 <5, 0> 都会保留。