输出中带有 nan 的 Flax 神经网络

问题描述 投票:0回答:1

我正在使用 Flax 训练神经网络。我的训练数据的输出中有大量的 nan。我想忽略这些,只使用非纳米值进行训练。为了实现这一目标,我尝试使用

jnp.nanmean
来计算损失,即:

def nanloss(params, inputs, targets):
    pred = model.apply(params, inputs)
    return jnp.nanmean((pred - targets) ** 2)

def train_step(state, inputs, targets):
    loss, grads = jax.value_and_grad(nanloss)(state.params, inputs, targets)
    state = state.apply_gradients(grads=grads)
    return state, loss

然而,经过一个训练步骤后,损失为 nan。

我想要实现的目标可能吗?如果是这样,我该如何解决这个问题?

python tensorflow deep-learning jax flax
1个回答
0
投票

我怀疑您遇到了这里讨论的问题:JAX FAQ:梯度包含 NaN,其中使用

where
。如果这确实是问题所在,您可以通过在计算损失之前过滤值来解决此问题;例如这样:

def nanloss(params, inputs, targets):
    pred = model.apply(params, inputs)
    mask = jnp.isnan(pred) | jnp.isnan(targets)
    pred = jnp.where(mask, 0, pred)
    targets = jnp.where(mask, 0, targets)
    return jnp.mean((pred - targets) ** 2, where=~mask)
© www.soinside.com 2019 - 2024. All rights reserved.