如何将 np.float128 注册为有效的 numba 类型,以便我可以对数组求和?

问题描述 投票:0回答:1

我想编写一个 numba 函数,该函数获取 np.float128s 数组的总和。 这些在我的设置中具有 80 位精度,但如果更容易将它们转换为真正的 float128,我也会对此感到满意。我的目标是能够快速对 np.float128 数组求和且不损失精度,最终将 np.float128 返回到我的 python 代码。

为此,我尝试从 https://github.com/gmarkall/numba-bfloat16/blob/main/prototype.py 复制模型。 我跳过与 CUDA 相关的部分,因为我不会编写 CUDA 代码,但我不知道是否需要用某些东西替换它们。 到目前为止我已经:

import numpy as np from numba.core.types.scalars import Number from numba.np import numpy_support class Float128(Number): def __init__(self, *args, **kws): super().__init__(name='float128') float128_type = Float128() numpy_support.FROM_DTYPE[np.dtype(np.float128)] = float128_type

链接的下一部分是:

@intrinsic def bfloat16_add(typingctx, a, b): sig = bfloat16_type(bfloat16_type, bfloat16_type) def codegen(context, builder, sig, args): i16 = ir.IntType(16) function_type = ir.FunctionType(i16, [i16, i16]) instruction = ("{.reg.b16 one; " "mov.b16 one, 0x3f80U; " "fma.rn.bf16 $0, $1, one, $2;}") asm = ir.InlineAsm(function_type, instruction, "=h,h,h") return builder.call(asm, args) return sig, codegen

我的 float128 应该用什么替换 bfloat16_add ?我想我需要为 numba 定义加法....?

python numba
1个回答
0
投票
a+b

被实现为

one * a + b
- 对于 BF16 LLVM 仅实现融合乘法和加法 (FMA),这里使用常量
1
来滥用。
您需要在此处在 LLVM 中实现 FP128 添加。这意味着从 128 位中提取 80 个值位,将它们用作 

x86_fp80

值,将它们相加,添加回 48 位填充,然后返回。无需 FMA 破解。

    

© www.soinside.com 2019 - 2024. All rights reserved.