我遇到过一种场景,将
jax.grad
应用于具有 jax.lax.switch
和复合布尔条件的函数会产生 jax.errors.TracerBoolConversionError
。重现此行为的最小程序如下:
from jax.lax import switch
import jax.numpy as jnp
from jax import grad
func_0 = lambda x: jnp.where(0. < x < 1., x, 0.)
func_1 = lambda x: jnp.where(0. < x < 1., x, 1.)
func_list = [func_0, func_1]
func = lambda index, x: switch(index, func_list, x)
df = grad(func, argnums=1)(1, 2.)
print(df)
错误如下:
Traceback (most recent call last):
File "***/grad_test.py", line 12, in <module>
df = grad(func, argnums=1)(1, 0.5)
File "***/grad_test.py", line 10, in <lambda>
func = lambda index, x: switch(index, func_list, x)
File "***/grad_test.py", line 5, in <lambda>
func_0 = lambda x: jnp.where(0 < x < 1., x, 0.)
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function <lambda> at ***/grad_test.py:5 for switch. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
但是,如果将布尔条件更改为单个条件(例如,
x < 1
),则不会发生错误。我想知道这是否可能是一个错误,或者其他情况,应该如何更改原始程序。
您不能将链式不等式与 JAX 或 NumPy 数组一起使用。您应该编写
0 < x < 1
,而不是 (0 < x) & (x < 1)
(请注意,由于运算符优先级,此处的括号不是可选的)。