numba 按类型调度

问题描述 投票:0回答:1

我想在 numba 中的函数中分派第二个参数的类型,但失败了。

如果它是一个整数,那么应该返回一个向量, 如果它本身是一个整数数组,那么应该返回一个矩阵。

第一个代码不起作用

@njit
def test_dispatch(X, indices):
    if isinstance(indices, nb.int64):
        ref_pos = np.empty(3, np.float64)
        ref_pos[:] = X[:, indices]
        return ref_pos
    elif isinstance(indices, nb.int64[:]):
        ref_pos = np.empty((3, len(indices)), np.float64)
        ref_pos[:, :] = X[:, indices]
        return ref_pos

而第二个带有

else
的则可以。

@njit
def test_dispatch(X, indices):
    if isinstance(indices, nb.int64):
        ref_pos = np.empty(3, np.float64)
        ref_pos[:] = X[:, indices]
        return ref_pos
    else:
        ref_pos = np.empty((3, len(indices)), np.float64)
        ref_pos[:, :] = X[:, indices]
        return ref_pos

我猜问题是通过

nb.int64[:]
的类型声明,但我没有让它以任何其他方式工作。 你有想法吗?

python types numba
1个回答
0
投票

您不应该在这样的 JIT 函数中使用

isinstance
,而应使用专门为此目的而设计的
@generated_jit
。这使得 Numba 能够更快地生成代码,因为仅针对每种情况编译函数的一部分,而不是针对每种专业化的所有情况。此外,
isinstance
是实验性的,正如 Numba 在执行第一个代码时发出的警告中所指定的(报告警告以供用户阅读;))。

这是一个关于泛型类型的推理示例:

import numba as nb
import numpy as np

@nb.generated_jit(nopython=True)
def test_dispatch(X, indices):
    if isinstance(indices, nb.types.Integer):
        def test_dispatch_scalar(X, indices):
            ref_pos = np.empty(3, np.float64)
            ref_pos[:] = X[:, indices]
            return ref_pos
        return test_dispatch_scalar
    elif isinstance(indices, nb.types.Array):
        def test_dispatch_vector(X, indices):
            ref_pos = np.empty((3, len(indices)), np.float64)
            ref_pos[:, :] = X[:, indices]
            return ref_pos
        return test_dispatch_vector
    else:
        assert False # Unsupported

以下是关于特定类型的推理示例:

import numba as nb
import numpy as np

@nb.generated_jit(nopython=True)
def test_dispatch(X, indices):
    if indices == nb.types.int64:
        def test_dispatch_scalar(X, indices):
            ref_pos = np.empty(3, np.float64)
            ref_pos[:] = X[:, indices]
            return ref_pos
        return test_dispatch_scalar
    elif isinstance(indices, nb.types.Array) and indices.ndim == 1 and indices.dtype == nb.types.int64:
        def test_dispatch_vector(X, indices):
            ref_pos = np.empty((3, len(indices)), np.float64)
            ref_pos[:, :] = X[:, indices]
            return ref_pos
        return test_dispatch_vector
    else:
        assert False # Unsupported

专门请求 64 位整数可能有点限制太多,因此我建议您混合通用类型测试和特定类型测试。出于同样的原因,您应该避免直接测试数组是否属于特定类型,因为它们通常可以是连续的或不连续的,或者可以包含与您的函数兼容的项目类型。

请注意,通用 JIT 函数旨在生成针对目标输入类型(而不是值)单独编译的函数。

© www.soinside.com 2019 - 2024. All rights reserved.