CUDA中从float*到float3*的转换是否安全?

问题描述 投票:0回答:1

我刚刚开始接触CUDA代码,它有点像过去的爆炸,大量的指针访问和通过指针进行类型转换,使用的是 reinterpret_cast. 我有一个特殊的情况,我想检查一下,我在代码中看到了以下类型惩罚的例子。

__device__ void func(__restrict__ float* const points, size_t size, __restrict__ float* outputPoints) {

    for (size_t index = 0; index < size; index += 3) {
        float3* const point = reinterpret_cast<float3* const>(points + index);
        float3* const output = reinterpret_cast<float3* const>(outputPoints + index);
        // operations using point;
    }
}

在CUDA中,你会得到一个结构 float3 这看起来像。

struct float3 {
    float x, y, z
}

这种行为能保证安全吗?这显然是某种类型的惩罚,但我非常担心可能会有一些填充或对齐或其他东西会以这种方式破坏访问。如果有人能够进一步深入了解cuda编译器如何处理这个问题,因为我知道它也做了一些很重的优化。这些会不会造成问题?

c++ cuda type-conversion type-safety
1个回答
4
投票

CUDA保证那些内置类型的大小在主机和设备之间是一致的,没有填充干预(对于用户定义的结构和类,不存在这样的保证)。

设备上有对齐的基本要求,比如你读取的存储必须与读取的大小对齐。所以你不能读取一个 float3 从任意字节边界读取,但你从32位对齐的边界读取是安全的,CUDA在主机和设备上公开的内存分配API保证了必要的对齐,使你发布的代码是安全的。

你所发的代码(修改后可以战胜死代码删除),基本上只是发出三个32位的加载和三个32位的存储。CUDA只有有限的原生事务大小,而且它们并不能映射到每个线程请求的96位,所以这样做完全没有优化。

__device__ void func(float* const points, size_t size, float* outputPoints) {

    for (size_t index = 0; index < size; index += 3) {
        float3* point = reinterpret_cast<float3*>(points + index);
        float3* output = reinterpret_cast<float3*>(outputPoints + index);

    float3 val = *point;
    val.x += 1.f; val.y += 2.f; val.z += 3.f;
    *output = val;
    }
}

所以这样做完全没有优化:

$ nvcc -arch=sm_75 -std=c++11 -dc -ptx fffloat3.cu 
$ tail -40 fffloat3.ptx 
    // .globl   _Z4funcPfmS_
.visible .func _Z4funcPfmS_(
    .param .b64 _Z4funcPfmS__param_0,
    .param .b64 _Z4funcPfmS__param_1,
    .param .b64 _Z4funcPfmS__param_2
)
{
    .reg .pred  %p<3>;
    .reg .f32   %f<7>;
    .reg .b64   %rd<14>;


    ld.param.u64    %rd11, [_Z4funcPfmS__param_0];
    ld.param.u64    %rd8, [_Z4funcPfmS__param_1];
    ld.param.u64    %rd12, [_Z4funcPfmS__param_2];
    setp.eq.s64 %p1, %rd8, 0;
    mov.u64     %rd13, 0;
    @%p1 bra    BB6_2;

BB6_1:
    ld.f32  %f1, [%rd11];
    ld.f32  %f2, [%rd11+4];
    ld.f32  %f3, [%rd11+8];
    add.f32     %f4, %f1, 0f3F800000;
    add.f32     %f5, %f2, 0f40000000;
    add.f32     %f6, %f3, 0f40400000;
    st.f32  [%rd12], %f4;
    st.f32  [%rd12+4], %f5;
    st.f32  [%rd12+8], %f6;
    add.s64     %rd12, %rd12, 12;
    add.s64     %rd11, %rd11, 12;
    add.s64     %rd13, %rd13, 3;
    setp.lt.u64 %p2, %rd13, %rd8;
    @%p2 bra    BB6_1;

BB6_2:
    ret;
}

也就是说,所有的铸造在语法上是虚假的,也是毫无意义的。

如果你要改成 float2,这是每个线程的64位请求,可以矢量化,所以得到这个。

.visible .func _Z4funcPfmS_(
    .param .b64 _Z4funcPfmS__param_0,
    .param .b64 _Z4funcPfmS__param_1,
    .param .b64 _Z4funcPfmS__param_2
)
{
    .reg .pred  %p<3>;
    .reg .f32   %f<7>;
    .reg .b64   %rd<14>;


    ld.param.u64    %rd12, [_Z4funcPfmS__param_0];
    ld.param.u64    %rd8, [_Z4funcPfmS__param_1];
    ld.param.u64    %rd11, [_Z4funcPfmS__param_2];
    setp.eq.s64 %p1, %rd8, 0;
    mov.u64     %rd13, 0;
    @%p1 bra    BB6_2;

BB6_1:
    ld.v2.f32   {%f1, %f2}, [%rd12];
    add.f32     %f5, %f2, 0f40000000;
    add.f32     %f6, %f1, 0f3F800000;
    st.v2.f32   [%rd11], {%f6, %f5};
    add.s64     %rd12, %rd12, 8;
    add.s64     %rd11, %rd11, 8;
    add.s64     %rd13, %rd13, 2;
    setp.lt.u64 %p2, %rd13, %rd8;
    @%p2 bra    BB6_1;

BB6_2:
    ret;
}

请注意,加载和存储现在使用的是 矢量化版本的指令。同样的 float4:

    // .globl   _Z4funcPfmS_
.visible .func _Z4funcPfmS_(
    .param .b64 _Z4funcPfmS__param_0,
    .param .b64 _Z4funcPfmS__param_1,
    .param .b64 _Z4funcPfmS__param_2
)
{
    .reg .pred  %p<3>;
    .reg .f32   %f<12>;
    .reg .b64   %rd<14>;


    ld.param.u64    %rd12, [_Z4funcPfmS__param_0];
    ld.param.u64    %rd8, [_Z4funcPfmS__param_1];
    ld.param.u64    %rd11, [_Z4funcPfmS__param_2];
    setp.eq.s64 %p1, %rd8, 0;
    mov.u64     %rd13, 0;
    @%p1 bra    BB6_2;

BB6_1:
    ld.v4.f32   {%f1, %f2, %f3, %f4}, [%rd12];
    add.f32     %f9, %f3, 0f40400000;
    add.f32     %f10, %f2, 0f40000000;
    add.f32     %f11, %f1, 0f3F800000;
    st.v4.f32   [%rd11], {%f11, %f10, %f9, %f4};
    add.s64     %rd12, %rd12, 8;
    add.s64     %rd11, %rd11, 8;
    add.s64     %rd13, %rd13, 2;
    setp.lt.u64 %p2, %rd13, %rd8;
    @%p2 bra    BB6_1;

BB6_2:
    ret;
}

TLDR:你的担心是有道理的,但API和编译器会理智地处理合理的情况,但在试图编写 "最佳代码 "之前,你应该非常熟悉对齐和硬件限制,因为除非你清楚地知道自己在做什么,否则有可能写出很多毫无意义的废话。

© www.soinside.com 2019 - 2024. All rights reserved.