我想在 numba 中的函数中分派第二个参数的类型,但失败了。
如果它是一个整数,那么应该返回一个向量, 如果它本身是一个整数数组,那么应该返回一个矩阵。
第一个代码不起作用
@njit
def test_dispatch(X, indices):
if isinstance(indices, nb.int64):
ref_pos = np.empty(3, np.float64)
ref_pos[:] = X[:, indices]
return ref_pos
elif isinstance(indices, nb.int64[:]):
ref_pos = np.empty((3, len(indices)), np.float64)
ref_pos[:, :] = X[:, indices]
return ref_pos
而第二个带有
else
的则可以。
@njit
def test_dispatch(X, indices):
if isinstance(indices, nb.int64):
ref_pos = np.empty(3, np.float64)
ref_pos[:] = X[:, indices]
return ref_pos
else:
ref_pos = np.empty((3, len(indices)), np.float64)
ref_pos[:, :] = X[:, indices]
return ref_pos
我猜问题是通过
nb.int64[:]
的类型声明,但我没有让它以任何其他方式工作。
你有想法吗?
您不应该在这样的 JIT 函数中使用
isinstance
,而应使用专门为此目的而设计的 @generated_jit
。这使得 Numba 能够更快地生成代码,因为仅针对每种情况编译函数的一部分,而不是针对每种专业化的所有情况。此外,isinstance
是实验性的,正如 Numba 在执行第一个代码时发出的警告中所指定的(报告警告以供用户阅读;))。
这是一个关于泛型类型的推理示例:
import numba as nb
import numpy as np
@nb.generated_jit(nopython=True)
def test_dispatch(X, indices):
if isinstance(indices, nb.types.Integer):
def test_dispatch_scalar(X, indices):
ref_pos = np.empty(3, np.float64)
ref_pos[:] = X[:, indices]
return ref_pos
return test_dispatch_scalar
elif isinstance(indices, nb.types.Array):
def test_dispatch_vector(X, indices):
ref_pos = np.empty((3, len(indices)), np.float64)
ref_pos[:, :] = X[:, indices]
return ref_pos
return test_dispatch_vector
else:
assert False # Unsupported
以下是关于特定类型的推理示例:
import numba as nb
import numpy as np
@nb.generated_jit(nopython=True)
def test_dispatch(X, indices):
if indices == nb.types.int64:
def test_dispatch_scalar(X, indices):
ref_pos = np.empty(3, np.float64)
ref_pos[:] = X[:, indices]
return ref_pos
return test_dispatch_scalar
elif isinstance(indices, nb.types.Array) and indices.ndim == 1 and indices.dtype == nb.types.int64:
def test_dispatch_vector(X, indices):
ref_pos = np.empty((3, len(indices)), np.float64)
ref_pos[:, :] = X[:, indices]
return ref_pos
return test_dispatch_vector
else:
assert False # Unsupported
专门请求 64 位整数可能有点限制太多,因此我建议您混合通用类型测试和特定类型测试。出于同样的原因,您应该避免直接测试数组是否属于特定类型,因为它们通常可以是连续的或不连续的,或者可以包含与您的函数兼容的项目类型。
请注意,通用 JIT 函数旨在生成针对目标输入类型(而不是值)单独编译的函数。