我是 JAX 新手,阅读文档时我发现 jitted 函数不应包含迭代器(有关纯函数的部分)
他们带来了这个例子:
import jax.numpy as jnp
import jax.lax as lax
from jax import jit
# lax.fori_loop
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0
尝试稍微摆弄它,看看我是否可以直接得到错误,而不是我写的未定义行为
@jit
def f(x, arr):
for i in range(10):
x += arr[i]
return x
@jit
def f1(x, arr):
it = iter(arr)
for i in range(10):
x += next(it)
return x
print(f(0,array)) # 45 as expected
print(f1(0,array)) # still 45
jitted 函数 f1() 现在显示正确的行为是“机会”吗?
您的代码之所以有效,是因为 JAX 跟踪模型的工作方式。当 JAX 的跟踪遇到 Python 控制流(如
for
循环)时,循环会在跟踪时进行全面评估(JAX Sharp Bits:控制流 中对此有一些探索)。
因此,在这种情况下使用迭代器是没有问题的,因为每次迭代都会在跟踪时进行评估,因此
next(it)
在每次迭代时都会重新评估。
相反,当使用
lax.fori_loop
时,next(iterator)
仅执行一次,其输出被视为跟踪时间常量,在运行时迭代期间不会改变。