如何编译具有可变输入类型的numba jit'ed函数?

问题描述 投票:2回答:1

假设我有一个可以接受intNone类型作为输入参数的函数

import numba as nb
import numpy as np

jitkw = {"nopython": True, "nogil": True, "error_model": "numpy", "fastmath": True}


@nb.jit("f8(i8)", **jitkw)
def get_random(seed=None):
    np.random.seed(None)
    out = np.random.normal()
    return out

我希望函数只返回一个正态分布的随机数。如果我想要可重复的结果,种子应该是int

get_random(42)
>>> 0.4967141530112327
get_random(42)
>>> 0.4967141530112327
get_random(42)
>>> 0.4967141530112327

如果我想要随机数字,seed应保留为None。但是,如果我没有传递参数(所以种子默认为None)或明确传递seed=None,那么numba会引发一个TypeError

get_random()
>>> TypeError: No matching definition for argument type(s) omitted(default=None)
get_random(None)
>>> TypeError: No matching definition for argument type(s) omitted(default=None)

如何编写函数,仍然声明签名并使用nopython模式进行此类场景?

我的numba版本是0.43.1

python random signature optional-parameters numba
1个回答
2
投票

第一个问题是nopython模式下的numba只接受(从版本0.43.1开始)np.random.seed: with an integer argument only

所以,不幸的是,你无法通过None


第二个问题是(据我所知)没有“单一”签名告诉numba如何处理缺失值,但是你可以使用两个签名(是的,它非常详细):

import numba as nb
import numpy as np

jitkw = {"nopython": True, "nogil": True, "error_model": "numpy", "fastmath": True}

@nb.jit(
    [nb.types.float64(nb.types.misc.Omitted(None)), 
     nb.types.float64(nb.types.int64)], 
    **jitkw)
def get_random(seed=None):
    return np.random.normal()

关于签名的两个部分的简短说明:

  • 如果省略参数,nb.types.float64(nb.types.misc.Omitted(None))告诉numba使用None作为默认类型
  • nb.types.float64(nb.types.int64)是期望整数的签名。

就个人而言,我不会指定签名,只是让numba弄明白。在numba中显式签名很少值得,而且更常见的是它们不会导致代码更慢且更不灵活。

© www.soinside.com 2019 - 2024. All rights reserved.