我正在使用 Flax 训练神经网络。我的训练数据的输出中有大量的 nan。我想忽略这些,只使用非纳米值进行训练。为了实现这一目标,我尝试使用
jnp.nanmean
来计算损失,即:
def nanloss(params, inputs, targets):
pred = model.apply(params, inputs)
return jnp.nanmean((pred - targets) ** 2)
def train_step(state, inputs, targets):
loss, grads = jax.value_and_grad(nanloss)(state.params, inputs, targets)
state = state.apply_gradients(grads=grads)
return state, loss
然而,经过一个训练步骤后,损失为 nan。
我想要实现的目标可能吗?如果是这样,我该如何解决这个问题?
我怀疑您遇到了这里讨论的问题:JAX FAQ:梯度包含 NaN,其中使用
where
。如果这确实是问题所在,您可以通过在计算损失之前过滤值来解决此问题;例如这样:
def nanloss(params, inputs, targets):
pred = model.apply(params, inputs)
mask = jnp.isnan(pred) | jnp.isnan(targets)
pred = jnp.where(mask, 0, pred)
targets = jnp.where(mask, 0, targets)
return jnp.mean((pred - targets) ** 2, where=~mask)