我在 python 中使用 jax,我想随机循环一些代码。这是稍后进行 jit 编译的函数的一部分。我下面有一个小例子,它可以解释我想要做什么。
num_iters = jax.random.randint(jax.random.PRNGKey(0), (1,), 1, 10)[0]
arr = []
for i in range(num_iters):
arr += [i*i]
这可以正常工作,没有任何错误,并在循环结束时给出
arr=[0,1,4]
(带有我们在 0
中使用的 PRNGKey
的固定种子)。
但是,如果这是即时编译函数的一部分:
@jax.jit
def do_stuff(start):
num_iters = jax.random.randint(jax.random.PRNGKey(0), (1,), 1, 10)[0]
arr = []
for i in range(num_iters):
arr += [i*i]
for value in arr:
start += value
return start
我在
TracerIntegerConversionError
上得到 num_iters
。该函数在没有 jit 装饰器的情况下也可以正常工作。如何让它与 jit 一起工作?我基本上只是想构造一个列表 arr
,其长度取决于随机数。或者,我也可以使用具有最大可能大小的列表,但随后我必须对其进行随机次数的循环。
更多背景
可以使用
numpy
随机数生成器使其不抛出错误:
@jax.jit
def do_stuff(start):
np_rng = np.random.default_rng()
num_iters = np_rng.integers(1, 10)
arr = []
for i in range(num_iters):
arr += [i*i]
for value in arr:
start += value
return start
然而,这不是我想要的。有一个 jax
rng
传递给我的函数,我希望用它来生成 num_iters
。否则, arr
始终具有相同的长度,因为 numpy
种子固定为 jit 编译时可用的长度,并且我总是得到相同的结果,没有任何随机性。但是,如果我使用 rng
键作为 numpy
的种子(如 np.random.default_rng(seed=rng[0])
),它会再次出现以下错误:
TypeError: SeedSequence expects int or sequence of ints for entropy not Traced<ShapedArray(uint32[])>with<DynamicJaxprTrace(level=1/0)>
Jax 在这种情况下会抱怨,因为您尝试将跟踪值用作静态整数。请参阅 https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError 了解更多信息。
作为一种可能的解决方案,您可以将
num_iters
作为参数传递给 do_stuff
,将其声明为静态并在外部创建键,如下所示:
import jax
from functools import partial
@partial(jax.jit, static_argnums=(1,))
def do_stuff(start, num_iters):
arr = []
for i in range(num_iters):
arr += [i*i]
for value in arr:
start += value
return start
key = jax.random.PRNGKey(238)
for _ in range(4):
key, _ = jax.random.split(key)
num_iters = int(jax.random.randint(key, (1,), 1, 10))
print(do_stuff(0, num_iters))
哪个打印:
5
0
140
30
我上面列出的链接中提出了其他替代解决方案。
我希望这有帮助!