多个布尔条件相当于 `jax.lax.cond`

问题描述 投票:0回答:1

目前

jax.lax.cond
适用于一种布尔条件。有没有办法将其扩展到多个布尔条件?

作为示例,下面是一个不可追踪的函数:

def func(x):
    if x < 0: return x
    elif (x >= 0) & (x < 1): return 2*x
    else: return 3*x

如何在JAX中以可追踪的方式编写这个函数?

python conditional-statements jit jax
1个回答
0
投票

编写此类内容的一种紧凑方法是使用

jnp.select
:

import jax
import jax.numpy as jnp

@jax.jit
def func(x):
  return jnp.select([x < 0, x < 1], [x, 2 * x], default=3 * x)

x = jnp.array([-0.5, 0.5, 1.5])
print(func(x))
# [-0.5  1.   4.5]
© www.soinside.com 2019 - 2024. All rights reserved.