`jax.grad`的矢量化能力

问题描述 投票:0回答:1

我正在尝试对以下“梯度幂”函数进行矢量化,以便它接受多个

order
:(参见此处

def grad_pow(f, order, argnum):

    for i in jnp.arange(order):
        f = grad(f, argnums=argnum)

    return f

在参数

vmap
上应用
order
后,此函数会产生以下错误:

jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
It arose in the jnp.arange argument 'stop'

我尝试使用

grad_pow
jax.lax.cond
编写
jax.lax.scan
的静态版本,遵循逻辑 here:

def static_grad_pow(f, order, argnum):

    order_max = 3  ## maximum order

    def grad_pow(f, i):
        return cond(i <= order, grad(f, argnum), f), None

    return scan(grad_pow, f, jnp.arange(order_max+1))[0]


if __name__ == "__main__":

    test_func = lambda x: jnp.exp(-2*x)
    test_func_grad_pow = static_grad_pow(jax.tree_util.Partial(test_func), 1, 0)
    print(test_func_grad_pow(1.))

尽管如此,这个解决方案仍然会产生错误:

    return cond(i <= order, grad(f, argnum), f), None
TypeError: differentiating with respect to argnums=0 requires at least 1 positional arguments to be passed by the caller, but got only 0 positional arguments.

只是想知道如何解决这个问题?

python loops vectorization jax
1个回答
0
投票

您的问题的根本问题是 vmapped 函数不能返回函数,它只能返回数组。除了所有其他细节之外,这排除了编写执行您想要的功能的有效函数的任何可能性。

还有其他选择:例如,您可以创建一个接受参数并将该函数应用于这些参数的函数,而不是尝试创建一个返回函数的函数。

在这种情况下,您将遇到另一个问题:如果追踪到

n
,则无法应用
grad
n
次。像
grad
这样的 JAX 转换是在跟踪时评估的,并且像
n
这样的跟踪值直到运行时才可用。解决此问题的一种方法是预先定义您感兴趣的所有函数,并在运行时使用
lax.switch
在它们之间进行选择。结果看起来像这样:

import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, static_argnums=[0], static_argnames=['argnum', 'max_order'])
def apply_multi_grad(f, order, *args, argnum=0, max_order=10):
  funcs = [f]
  for i in range(max_order):
    funcs.append(jax.grad(funcs[-1], argnum))
  return jax.lax.switch(order, funcs, *args)


order = jnp.arange(3)
x = jnp.ones(3)
f = jnp.sin

print(jax.vmap(apply_multi_grad, in_axes=(None, 0, 0))(f, order, x))
# [ 0.84147096  0.5403023  -0.84147096]

# Compare by doing it manually:
print(jnp.array([f(x[0]), jax.grad(f)(x[1]), jax.grad(jax.grad(f))(x[2])]))
# [ 0.84147096  0.5403023  -0.84147096]
© www.soinside.com 2019 - 2024. All rights reserved.