PyCUDA call to cublasDgetrfBatched fails when performing an LU decomposition

Question

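I want to use scikit-cuda and PyCUDA to compute the LU decomposition of a square matrix A. I tried some of the demo code from the scikit-cuda GitHub repository, and everything seemed to work fine.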

However, when I call the low-level interface cublas.cublasDgetrfBatched in my own small example, it fails with the following error:

PyCUDA WARNING: a clean-up operation failed (dead context maybe?)
cuMemFree failed: an illegal memory access was encountered

Here is my small Python script:

import numpy as np
import pycuda.autoinit
import skcuda.cublas as cublas
import pycuda.gpuarray as gpuarray

N = 10
N_BATCH = 1  # only 1 matrix to be decomposed
A_SHAPE = (N, N)

a = np.random.rand(*A_SHAPE).astype(np.float64)
a_batch = np.expand_dims(a, axis=0)

a_gpu = gpuarray.to_gpu(a_batch.T.copy())  # transpose a to follow "F" order
p_gpu = gpuarray.zeros(N * N_BATCH, np.int32)
info_gpu = gpuarray.zeros(N_BATCH, np.int32)

cublas_handle = cublas.cublasCreate()
cublas.cublasDgetrfBatched(
    cublas_handle,
    N,
    a_gpu.gpudata,
    N,
    p_gpu.gpudata,
    info_gpu.gpudata,
    N_BATCH,
)

cublas.cublasDestroy(cublas_handle)
print(a_gpu)

I'm new to scikit-cuda. Can anyone help me?
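A note on the memory layout, as a hedged sketch rather than part of the original post: for a single matrix, `a_batch.T.copy()` does produce column-major ("F" order) data, but `.T` reverses all three axes, so with `N_BATCH > 1` the matrices would end up interleaved in memory. The NumPy-only example below (a small batch with made-up values) swaps just the last two axes instead, which keeps each matrix contiguous and column-major. Assuming the rest of the setup above, the resulting `a_colmajor` array is what one would pass to `gpuarray.to_gpu` for a larger batch.

```python
import numpy as np

N = 3
N_BATCH = 2

# Small illustrative batch of row-major (C-order) matrices, shape (N_BATCH, N, N)
a_batch = np.arange(N_BATCH * N * N, dtype=np.float64).reshape(N_BATCH, N, N)

# Swap only the last two axes: each matrix stays contiguous in memory and is
# stored column by column, the layout cuBLAS/LAPACK expect ("F" order).
a_colmajor = np.ascontiguousarray(a_batch.transpose(0, 2, 1))

flat = a_colmajor.ravel()
for i in range(N_BATCH):
    block = flat[i * N * N:(i + 1) * N * N]
    # Block i is matrix i flattened in column-major order
    assert np.array_equal(block, a_batch[i].ravel(order="F"))

# a_batch.T reverses *all* axes (shape (N, N, N_BATCH)), so element (0, 0)
# of matrix 0 is followed by element (0, 0) of matrix 1: the batch is interleaved.
interleaved_prefix = a_batch.T.copy().ravel()[:4]
print(interleaved_prefix)  # [ 0.  9.  3. 12.]

# With a single matrix both approaches agree, which is why a_batch.T.copy()
# is fine when N_BATCH == 1.
single_T = a_batch[:1].T.copy().ravel()
assert np.array_equal(single_T, a_colmajor[0].ravel())
```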

cuda pycuda scikit-cuda
1 Answer

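As @talonmies pointed out in a comment, cublasDgetrfBatched expects a pointer to an array of device pointers (the addresses of the matrices on the device), not a pointer to the matrix data itself.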
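Why the original call crashes: the `A` argument of `cublasDgetrfBatched` is read on the device as an array of `double*`, one entry per matrix. Passing `a_gpu.gpudata` makes cuBLAS interpret the first 8 bytes of the matrix (the double `A[0,0]`) as an address, and dereferencing that garbage address produces the illegal memory access. The NumPy-only sketch below shows that reinterpretation and also builds the pointer array for a contiguous batch. `batch_pointers` is a hypothetical helper and the base address is made up; with PyCUDA one would presumably pass `a_gpu.ptr` as `base_ptr` and upload the result with `gpuarray.to_gpu`.

```python
import numpy as np

def batch_pointers(base_ptr, n, batch, itemsize=8):
    """Hypothetical helper: device addresses of `batch` n x n matrices stored
    back to back starting at `base_ptr` (n*n*itemsize bytes per matrix)."""
    offsets = np.arange(batch, dtype=np.uint64) * np.uint64(n * n * itemsize)
    return np.uint64(base_ptr) + offsets  # 64-bit addresses, one per matrix

# Made-up base address standing in for a_gpu.ptr
ptrs = batch_pointers(0x7F0000000000, n=10, batch=3)
print([hex(int(p)) for p in ptrs])  # ['0x7f0000000000', '0x7f0000000320', '0x7f0000000640']

# What happens when the matrix buffer itself is passed instead: cuBLAS reads
# the first 8 bytes (the double 0.5 below) as if they were a pointer.
a = np.array([[0.5, 1.0], [2.0, 3.0]])
bogus_address = int(a.ravel(order="F")[:1].view(np.uint64)[0])
print(hex(bogus_address))  # 0x3fe0000000000000, not a valid device address
```

Using `np.uint64` explicitly keeps every entry 8 bytes wide, matching a 64-bit device pointer, instead of relying on NumPy's default integer type for the platform.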

import numpy as np
import pycuda.autoinit
import skcuda.cublas as cublas
import pycuda.gpuarray as gpuarray

N = 10
N_BATCH = 1  # only 1 matrix to be decomposed
A_SHAPE = (N, N)

a = np.random.rand(*A_SHAPE).astype(np.float64)
a_batch = np.expand_dims(a, axis=0)

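a_gpu = gpuarray.to_gpu(a_batch.T.copy())  # transpose a to follow "F" order

# Building the pointer array as a plain host np.array won't work: cuBLAS reads
# it on the device, so it must be uploaded with gpuarray.to_gpu.
# .ptr is the device address of matrix a; for a larger batch this list can be
# built programmatically (one address per matrix).
# dtype=np.uint64 makes each entry an 8-byte device pointer.
a_gpu_batch = gpuarray.to_gpu(np.asarray([a_gpu.ptr], dtype=np.uint64))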

p_gpu = gpuarray.zeros(N * N_BATCH, np.int32)
info_gpu = gpuarray.zeros(N_BATCH, np.int32)

cublas_handle = cublas.cublasCreate()
cublas.cublasDgetrfBatched(
    cublas_handle,
    N,
    a_gpu_batch.gpudata,
    N,
    p_gpu.gpudata,
    info_gpu.gpudata,
    N_BATCH,
)

cublas.cublasDestroy(cublas_handle)
print(a_gpu)
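To check the result: like LAPACK's `dgetrf`, the batched routine overwrites each matrix with its packed LU factors (unit lower-triangular L below the diagonal, U on and above it), and the pivot indices should follow LAPACK's 1-based convention (row i was swapped with row `P[i]`). The NumPy-only sketch below unpacks the factors and verifies P·A = L·U on the host. So that it runs without a GPU, it uses a small hypothetical CPU factorization, `getrf_reference`, to produce data in that format; with the answer's code one would presumably use `a_gpu.get().ravel().reshape(N, N, order="F")` for the packed factors and `p_gpu.get()` for the pivots.

```python
import numpy as np

def getrf_reference(a):
    """Hypothetical CPU stand-in for getrf: LU with partial pivoting that
    returns the packed LU matrix and LAPACK-style 1-based pivots."""
    lu = np.array(a, dtype=np.float64, copy=True)
    n = lu.shape[0]
    piv = np.zeros(n, dtype=np.int32)
    for k in range(n):
        p = k + int(np.argmax(np.abs(lu[k:, k])))  # largest entry in column k
        piv[k] = p + 1                              # store as 1-based index
        if p != k:
            lu[[k, p]] = lu[[p, k]]                 # swap full rows
        lu[k + 1:, k] /= lu[k, k]                   # multipliers form column k of L
        lu[k + 1:, k + 1:] -= np.outer(lu[k + 1:, k], lu[k, k + 1:])
    return lu, piv

def unpack_lu(lu):
    """Split packed factors into unit lower-triangular L and upper-triangular U."""
    n = lu.shape[0]
    return np.tril(lu, k=-1) + np.eye(n), np.triu(lu)

def apply_pivots(a, piv):
    """Apply the recorded 1-based row interchanges to A in order (gives P @ A)."""
    pa = np.array(a, dtype=np.float64, copy=True)
    for i, p in enumerate(piv):
        j = int(p) - 1
        if j != i:
            pa[[i, j]] = pa[[j, i]]
    return pa

N = 10
a = np.random.default_rng(0).random((N, N))
lu, piv = getrf_reference(a)

# Round-trip through a column-major flat buffer, as a_gpu would hold it
gpu_like_flat = lu.ravel(order="F")
lu_host = gpu_like_flat.reshape(N, N, order="F")

L, U = unpack_lu(lu_host)
max_err = np.max(np.abs(apply_pivots(a, piv) - L @ U))
print(max_err)  # should be on the order of machine precision
```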
© www.soinside.com 2019 - 2024. All rights reserved.