带有 `jax.lax.switch` 和复合布尔条件的函数的 JAX `grad` 错误

问题描述 投票:0回答:1

我遇到过一种场景,将

jax.grad
应用于具有
jax.lax.switch
和复合布尔条件的函数会产生
jax.errors.TracerBoolConversionError
。重现此行为的最小程序如下:

from jax.lax import switch
import jax.numpy as jnp
from jax import grad

func_0 = lambda x: jnp.where(0. < x < 1., x, 0.)
func_1 = lambda x: jnp.where(0. < x < 1., x, 1.)

func_list = [func_0, func_1]

func = lambda index, x: switch(index, func_list, x)

df = grad(func, argnums=1)(1, 2.)
print(df)

错误如下:

Traceback (most recent call last):
  File "***/grad_test.py", line 12, in <module>
    df = grad(func, argnums=1)(1, 0.5)
  File "***/grad_test.py", line 10, in <lambda>
    func = lambda index, x: switch(index, func_list, x)
  File "***/grad_test.py", line 5, in <lambda>
    func_0 = lambda x: jnp.where(0 < x < 1., x, 0.)
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function <lambda> at ***/grad_test.py:5 for switch. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

但是,如果将布尔条件更改为单个条件(例如,

x < 1
),则不会发生错误。我想知道这是否可能是一个错误,或者其他情况,应该如何更改原始程序。

python boolean gradient jax
1个回答
0
投票

您不能将链式不等式与 JAX 或 NumPy 数组一起使用。您应该编写

0 < x < 1
,而不是
(0 < x) & (x < 1)
(请注意,由于运算符优先级,此处的括号不是可选的)。

© www.soinside.com 2019 - 2024. All rights reserved.