如何使用jit编译在jax中循环随机次数?

问题描述 投票:0回答:1

我在 python 中使用 jax,我想随机循环一些代码。这是稍后进行 jit 编译的函数的一部分。我下面有一个小例子,它可以解释我想要做什么。

num_iters = jax.random.randint(jax.random.PRNGKey(0), (1,), 1, 10)[0]
arr = []
for i in range(num_iters):
  arr += [i*i]

这可以正常工作,没有任何错误,并在循环结束时给出

arr=[0,1,4]
(带有我们在
0
中使用的
PRNGKey
的固定种子)。

但是,如果这是即时编译函数的一部分:

@jax.jit
def do_stuff(start):
  num_iters = jax.random.randint(jax.random.PRNGKey(0), (1,), 1, 10)[0]
  arr = []
  for i in range(num_iters):
    arr += [i*i]
  for value in arr:
    start += value
  return start

我在

TracerIntegerConversionError
上得到
num_iters
。该函数在没有 jit 装饰器的情况下也可以正常工作。如何让它与 jit 一起工作?我基本上只是想构造一个列表
arr
,其长度取决于随机数。或者,我也可以使用具有最大可能大小的列表,但随后我必须对其进行随机次数的循环。


更多背景

可以使用

numpy
随机数生成器使其不抛出错误:

@jax.jit
def do_stuff(start):
  np_rng = np.random.default_rng()
  num_iters = np_rng.integers(1, 10)
  arr = []
  for i in range(num_iters):
    arr += [i*i]
  for value in arr:
    start += value
  return start

然而,这不是我想要的。有一个 jax

rng
传递给我的函数,我希望用它来生成
num_iters
。否则,
arr
始终具有相同的长度,因为
numpy
种子固定为 jit 编译时可用的长度,并且我总是得到相同的结果,没有任何随机性。但是,如果我使用
rng
键作为
numpy
的种子(如
np.random.default_rng(seed=rng[0])
),它会再次出现以下错误:

TypeError: SeedSequence expects int or sequence of ints for entropy not Traced<ShapedArray(uint32[])>with<DynamicJaxprTrace(level=1/0)>
python random jit jax
1个回答
0
投票

Jax 在这种情况下会抱怨,因为您尝试将跟踪值用作静态整数。请参阅 https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError 了解更多信息。

作为一种可能的解决方案,您可以将

num_iters
作为参数传递给
do_stuff
,将其声明为静态并在外部创建键,如下所示:

import jax
from functools import partial

@partial(jax.jit, static_argnums=(1,))
def do_stuff(start, num_iters):    
  arr = []
  
  for i in range(num_iters):
    arr += [i*i]
  
  for value in arr:
    start += value
  
  return start

key = jax.random.PRNGKey(238)

for _ in range(4):
  key, _ = jax.random.split(key)
  num_iters = int(jax.random.randint(key, (1,), 1, 10))
  print(do_stuff(0, num_iters))

哪个打印:

5
0
140
30

我上面列出的链接中提出了其他替代解决方案。

我希望这有帮助!

© www.soinside.com 2019 - 2024. All rights reserved.