Softmax 和 JIT 问题中 NaN 的出现

问题描述 投票:0回答:1

我正在尝试使用 Jax 从头开始实现 Transformer 架构。我在训练中发现三个问题:

  1. jax.disable_jit()
    不会删除隐式 jit 编译。
  2. 为什么
    jax.nn.softmax
    默认调用
    _softmax_deprecated
  3. 我在减法中遇到 NaN
     _softmax_deprecated
    unnormalized = jnp.exp(x - lax.stop_gradient(x_max))
    如果需要的话我会附上代码供您参考:
class SelfAttention(eqx.Module):
    def __call__(self, query, key, value, mask):
        scaled_dot_prod = query @ jnp.transpose(key, (0, 2, 1)) / jnp.sqrt(query.shape[-1])
        scaled_dot_prod = mask + scaled_dot_prod
        return (jax.nn.softmax(scaled_dot_prod) @ value)

def create_mask(arr):
    return jnp.where(arr == 0, np.NINF, 0)

def loss(model, X, y, X_mask, y_mask, labels):
    y_pred = jnp.log(predict(model, X, y, X_mask, y_mask))
    y_pred = jnp.where(labels==0, 0, jnp.take(y_pred, labels, axis=-1))
    count = jnp.count_nonzero(y_pred)
    return -jnp.sum(y_pred)/count

with jax.disable_jit():
    for e in range(EPOCHS):
        total_loss = 0
        num_batches = 0
        total_tokens = 0
        for i, (Xbt, ybt, labelbt) in enumerate(dataloader(Xtr, ytr, SEQ_LEN)):
            total_tokens += len([token for seq in labelbt for token in list(filter(lambda x: x!=0, seq))])
            Xbt, ybt, labelbt = [jnp.array(x) for x in (Xbt, ybt, labelbt)]
            Xmask, ymask = [create_mask(x) for x in (Xbt, ybt)]

            model, opt_state, batch_loss = step(model, opt_state, Xbt, ybt, Xmask, ymask, labelbt)
            total_loss += batch_loss
            num_batches += 1

            if num_batches % 20 == 0:
                print(f"Batches trained: {num_batches} | Avg. Batch loss: {total_loss/num_batches}")

        epoch_loss = total_loss / num_batches
        print(f"Epoch {e} | loss: {epoch_loss}")

错误:

def _softmax_deprecated(
    478     x: ArrayLike,
    479     axis: Optional[Union[int, tuple[int, ...]]] = -1,
    480     where: Optional[ArrayLike] = None,
    481     initial: Optional[ArrayLike] = None) -> Array:
    482   x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
--> 483   unnormalized = jnp.exp(x - lax.stop_gradient(x_max))
    484   result = unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
    485   if where is not None:

FloatingPointError: invalid value (nan) encountered in jit(sub)

跨越200批次训练后才遇到上述问题。我没有检查是否跳过发生错误的特定批次。也许我应该检查某些特定输入是否是此错误的原因。

但是我找不到上述3个问题的答案:(

python nan jit jax
1个回答
0
投票

回答您的问题:

  1. jax.disable_jit()
    不会删除隐式 jit 编译。

如果这是真的,则这是一个错误,您应该在 JAX 问题跟踪器上报告它。从你的问题中不清楚是什么让你相信情况确实如此。

  1. 为什么 jax.nn.softmax 默认调用 _softmax_deprecated?

因为

_softmax_deprecated
是旧的默认算法,有一天它会被弃用,但弃用尚未发生。有关详细信息,请参阅https://github.com/google/jax/pull/15677。要使用较新的算法,您可以设置
jax_softmax_custom_jvp=True
配置。

  1. 我在 _softmax_deprecated 内的减法中遇到 NaN:
    unnormalized = jnp.exp(x - lax.stop_gradient(x_max))
    如果需要,我将附上代码供您参考:

您没有包含足够的代码来重现您的问题(下次,尝试添加一个最小可重现示例,以允许其他人无需猜测即可回答您的问题)。但值得设置

jax_softmax_custom_jvp=True
看看是否可以解决您的问题。上面链接的拉取请求有详细信息。

© www.soinside.com 2019 - 2024. All rights reserved.