下面是一个示例,其中具有自定义向量雅可比积 (
custom_vjp
) 的函数被 vmap
ped。对于像这样的简单函数,调用 vjp
会失败:
@partial(custom_vjp, nondiff_argnums=(0,))
def test_func(f: Callable[..., float],
R: Array
) -> float:
return f(jnp.dot(R, R))
def test_func_fwd(f, primal):
primal_out = test_func(f, primal)
residual = 2. * primal * primal_out
return primal_out, residual
def test_func_bwd(f, residual, cotangent):
cotangent_out = residual * cotangent
return (cotangent_out, )
test_func.defvjp(test_func_fwd, test_func_bwd)
test_func = vmap(test_func, in_axes=(None, 0))
if __name__ == "__main__":
def f(x):
return x
# vjp
primal, f_vjp = vjp(partial(test_func, f),
jnp.ones((10, 3))
)
cotangent = jnp.ones(10)
cotangent_out = f_vjp(cotangent)
print(cotangent_out[0].shape)
错误消息显示:
ValueError: Shape of cotangent input to vjp pullback function (10,) must be the same as the shape of corresponding primal input (10, 3).
在这里,我认为错误消息具有误导性,因为余切输入应具有与原始输出相同的形状,在本例中应为
(10, )
。不过,我不清楚为什么会发生这个错误。
问题在于,在
test_func_fwd
中,您递归地调用 test_func
,但您已使用其 vmapped 版本覆盖了全局命名空间中的 test_func
。如果您在全局命名空间中保留原始 test_func
不变,您的代码将按预期工作:
...
test_func_mapped = vmap(test_func, in_axes=(None, 0))
...
primal, f_vjp = vjp(partial(test_func_mapped, f),
jnp.ones((10, 3))
)