jit JAX 函数中的迭代器

问题描述 投票:0回答:1

我是 JAX 新手,阅读文档时我发现 jitted 函数不应包含迭代器(有关纯函数的部分

他们带来了这个例子:

import jax.numpy as jnp
import jax.lax as lax
from jax import jit

# lax.fori_loop
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0

尝试稍微摆弄它,看看我是否可以直接得到错误,而不是我写的未定义行为

@jit
def f(x, arr):
    for i in range(10):
        x += arr[i]
    return x

@jit
def f1(x, arr):
    it = iter(arr)
    for i in range(10):
        x += next(it)
    return x

print(f(0,array)) # 45 as expected
print(f1(0,array)) # still 45 

jitted 函数 f1() 现在显示正确的行为是“机会”吗?

python jit jax
1个回答
1
投票

您的代码之所以有效,是因为 JAX 跟踪模型的工作方式。当 JAX 的跟踪遇到 Python 控制流(如

for
循环)时,循环会在跟踪时进行全面评估(JAX Sharp Bits:控制流 中对此有一些探索)。

因此,在这种情况下使用迭代器是没有问题的,因为每次迭代都会在跟踪时进行评估,因此

next(it)
在每次迭代时都会重新评估。

相反,当使用

lax.fori_loop
时,
next(iterator)
仅执行一次,其输出被视为跟踪时间常量,在运行时迭代期间不会改变。

© www.soinside.com 2019 - 2024. All rights reserved.