我一直在尝试在 CUDA 中实现无锁队列,但由于某种原因我的代码出现死锁,特别是当我运行下面的测试用例时。
代码:
#include <atomic>
#include "stdio.h"
class lockless_pq{
private:
int* queue;
int head_index;
int tail_index;
public:
int MAX_SIZE;
lockless_pq(int max_size){
MAX_SIZE = max_size;
int* host_queue = new int[max_size];
//initialize entire queue with nullptrs
for (int i = 0; i < MAX_SIZE; i++){
host_queue[i] = -1;//atomic<nullptr>;
}
cudaMalloc((int**)&queue, sizeof(int)*MAX_SIZE);
cudaMemcpy(queue, host_queue, sizeof(int)*MAX_SIZE, cudaMemcpyHostToDevice);
head_index = 0;
tail_index = 0;
}
__device__ void push(int val){
int old_tail_index = tail_index;
while(tail_index == (head_index + MAX_SIZE-1) % MAX_SIZE || old_tail_index != atomicCAS(&tail_index, old_tail_index, (old_tail_index+1) % MAX_SIZE)){
old_tail_index = tail_index;
//printf("push loop\n");
}
//once here, know spot has been reserved to write in!
queue[old_tail_index] = val;
}
__device__ void pop(int* val){
int old_head_index = head_index;
while(head_index == tail_index || old_head_index != atomicCAS(&head_index, old_head_index, (old_head_index+1) % MAX_SIZE)){
old_head_index = head_index;
}
*val = queue[old_head_index];
}
bool is_empty(){
//if head == tail, or element after head is nullptr then empty!
if (head_index == tail_index){
return true;
}
return false;
}
};`
测试用例:
`#include "test.h"
#include "stdio.h"
__global__ void simple_push_pop_kernel(lockless_pq* pq){
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx == 0){
printf("entering\n");
}
pq->push(idx);=
pq->pop(&idx);
}
int main(int argc, char** argv){
//launch kernel with each thread adding its index to the pq and then popping it off
lockless_pq* pq = new lockless_pq(10);
printf("pq size: %d\n", pq->MAX_SIZE);
lockless_pq* dev_pq;
cudaMalloc((lockless_pq**)&dev_pq, sizeof(lockless_pq));
cudaMemcpy(dev_pq, pq, sizeof(lockless_pq), cudaMemcpyHostToDevice);
simple_push_pop_kernel<<<2, 5>>>(dev_pq);
cudaDeviceSynchronize();
}`
当我运行这个测试用例时,第一个块中线程的推送和弹出成功,但在第二个块中,没有一个推送完成并且我有一个无限循环。有趣的是,当我在 push 的 while 循环中添加 print 语句时,这个问题一直消失。
根据我一直在做的一些打印语句调试,我相信我的问题是因为扭曲,即如果一个扭曲线程到达我的 while 循环并逃脱它,其他线程将等待执行直到它完成,可能导致他们陷入僵局。我不认为这是一个与块相关的问题,因为我在 1 个块中测试 10 个线程时也遇到了死锁。
虽然我相当确定我的问题与扭曲有关,但我不是 100% 确定它到底是什么或如何开始修复它。
你能帮我弄清楚我的代码有什么问题以及我该如何解决吗?
非常感谢!