JAX `vmap` 对于多个参数的意外行为

问题描述 投票:0回答:1

我发现 JAX 中的

vmap
在应用于多个参数时不会按预期运行。例如,考虑下面的函数:

def f1(x, y, z):
    f = x[:, None, None] * z[None, None, :] + y[None, :, None]
    return f

对于

x = jnp.arange(7), y = jnp.arange(5), z = jnp.arange(3)
,该函数的输出具有形状
(7, 5, 3)
。但是,对于以下 vmap 版本:

@partial(vmap, in_axes=(None, 0, 0), out_axes=(1, 2))
def f2(x, y, z):
    f = x*z + y
    return f

它输出此错误:

ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * one axis had size 5: axis 0 of argument y of type int32[5];
  * one axis had size 3: axis 0 of argument z of type int32[3]

有人可以解释一下这个错误背后的原因吗?

python arrays vectorization jax
1个回答
0
投票

vmap
的语义是它沿着一个或多个数组执行单个批处理操作。当您指定
in_axes=(None, 0, 0)
时,含义是“同时沿
y
z
的主尺寸映射”:您看到的错误告诉您
y
z
的主尺寸有不同尺寸,因此它们不兼容批处理。

您的函数

f1
本质上使用广播来编码三个批处理操作,因此要使用
vmap
复制该逻辑,您将需要三个
vmap
应用程序。您可以这样表达:

@partial(vmap, in_axes=(0, None, None))
@partial(vmap, in_axes=(None, 0, None))
@partial(vmap, in_axes=(None, None, 0))
def f2(x, y, z):
    f = x*z + y
    return f
© www.soinside.com 2019 - 2024. All rights reserved.