jax:我们如何解决错误:pmap 被要求沿轴 0 映射其参数,这意味着它的等级应该至少为 1,但实际上仅为 0?

问题描述 投票:0回答:1

我正在尝试运行这个基于分数的生成建模的简单介绍。代码使用的是

flax.optim
,同时似乎已移至
optax
https://flax.readthedocs.io/en/latest/guides/converting_and_upgrading/optax_update_guide.html)。

我已经制作了 Colab 代码的副本,其中包含我认为需要进行的更改(我只是不确定需要如何替换

optimizer = flax.jax_utils.replicate(optimizer)
)。

现在,在训练部分,我收到错误

pmap 被要求沿轴 0 映射其参数,这意味着它的等级至少应为 1,但实际上仅为 0(其形状为 ())

loss, params, opt_state = train_step_fn(step_rng, x, params, opt_state)
线上。这显然来自“定义损失函数”部分中的
return jax.pmap(step_fn, axis_name='device')

如何修复此错误?我用谷歌搜索过,但不知道这里出了什么问题。

python neural-network jax pmap
1个回答
0
投票

发生这种情况是因为您将标量参数传递给 pmap 函数。例如:

import jax
func = lambda x: x ** 2
pfunc = jax.pmap(func)

pfunc(1.0)
# ValueError: pmap was requested to map its argument along axis 0, which implies
# that its rank should be at least 1, but is only 0 (its shape is ())

如果你想对标量进行操作,你应该使用该函数而不将其包装在

pmap
:

func(1.0)
# 1.0

或者,如果您想使用

pmap
,您应该对前导维度与设备数量匹配的数组进行操作:

num_devices = len(jax.devices())
x = jax.numpy.arange(num_devices)
pfunc(x)
# Array([ 0,  1,  4,  9, 16, 25, 36, 49], dtype=int32)
© www.soinside.com 2019 - 2024. All rights reserved.