JAX `vjp` 对于带有 `custom_vjp` 的 vmapped 函数失败

问题描述 投票:0回答:1

下面是一个示例,其中具有自定义向量雅可比积 (

custom_vjp
) 的函数被
vmap
ped。对于像这样的简单函数,调用
vjp
会失败:

@partial(custom_vjp, nondiff_argnums=(0,))
def test_func(f: Callable[..., float],
              R: Array
              ) -> float:

    return f(jnp.dot(R, R))


def test_func_fwd(f, primal):

    primal_out = test_func(f, primal)
    residual = 2. * primal * primal_out
    return primal_out, residual


def test_func_bwd(f, residual, cotangent):

    cotangent_out = residual * cotangent
    return (cotangent_out, )


test_func.defvjp(test_func_fwd, test_func_bwd)

test_func = vmap(test_func, in_axes=(None, 0))


if __name__ == "__main__":

    def f(x):
        return x

    # vjp
    primal, f_vjp = vjp(partial(test_func, f),
                        jnp.ones((10, 3))
                        )

    cotangent = jnp.ones(10)
    cotangent_out = f_vjp(cotangent)

    print(cotangent_out[0].shape)

错误消息显示:

ValueError: Shape of cotangent input to vjp pullback function (10,) must be the same as the shape of corresponding primal input (10, 3).

在这里,我认为错误消息具有误导性,因为余切输入应具有与原始输出相同的形状,在本例中应为

(10, )
。不过,我不清楚为什么会发生这个错误。

python vectorization jax automatic-differentiation
1个回答
0
投票

问题在于,在

test_func_fwd
中,您递归地调用
test_func
,但您已使用其 vmapped 版本覆盖了全局命名空间中的
test_func
。如果您在全局命名空间中保留原始
test_func
不变,您的代码将按预期工作:

...

test_func_mapped = vmap(test_func, in_axes=(None, 0))

... 

primal, f_vjp = vjp(partial(test_func_mapped, f),
                    jnp.ones((10, 3))
                    )
© www.soinside.com 2019 - 2024. All rights reserved.