我正在尝试使用 Jax 从头开始实现 Transformer 架构。我在训练中发现三个问题:
jax.disable_jit()
不会删除隐式 jit 编译。jax.nn.softmax
默认调用_softmax_deprecated
? _softmax_deprecated
:unnormalized = jnp.exp(x - lax.stop_gradient(x_max))
如果需要的话我会附上代码供您参考:class SelfAttention(eqx.Module):
def __call__(self, query, key, value, mask):
scaled_dot_prod = query @ jnp.transpose(key, (0, 2, 1)) / jnp.sqrt(query.shape[-1])
scaled_dot_prod = mask + scaled_dot_prod
return (jax.nn.softmax(scaled_dot_prod) @ value)
def create_mask(arr):
return jnp.where(arr == 0, np.NINF, 0)
def loss(model, X, y, X_mask, y_mask, labels):
y_pred = jnp.log(predict(model, X, y, X_mask, y_mask))
y_pred = jnp.where(labels==0, 0, jnp.take(y_pred, labels, axis=-1))
count = jnp.count_nonzero(y_pred)
return -jnp.sum(y_pred)/count
with jax.disable_jit():
for e in range(EPOCHS):
total_loss = 0
num_batches = 0
total_tokens = 0
for i, (Xbt, ybt, labelbt) in enumerate(dataloader(Xtr, ytr, SEQ_LEN)):
total_tokens += len([token for seq in labelbt for token in list(filter(lambda x: x!=0, seq))])
Xbt, ybt, labelbt = [jnp.array(x) for x in (Xbt, ybt, labelbt)]
Xmask, ymask = [create_mask(x) for x in (Xbt, ybt)]
model, opt_state, batch_loss = step(model, opt_state, Xbt, ybt, Xmask, ymask, labelbt)
total_loss += batch_loss
num_batches += 1
if num_batches % 20 == 0:
print(f"Batches trained: {num_batches} | Avg. Batch loss: {total_loss/num_batches}")
epoch_loss = total_loss / num_batches
print(f"Epoch {e} | loss: {epoch_loss}")
错误:
def _softmax_deprecated(
478 x: ArrayLike,
479 axis: Optional[Union[int, tuple[int, ...]]] = -1,
480 where: Optional[ArrayLike] = None,
481 initial: Optional[ArrayLike] = None) -> Array:
482 x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
--> 483 unnormalized = jnp.exp(x - lax.stop_gradient(x_max))
484 result = unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
485 if where is not None:
FloatingPointError: invalid value (nan) encountered in jit(sub)
跨越200批次训练后才遇到上述问题。我没有检查是否跳过发生错误的特定批次。也许我应该检查某些特定输入是否是此错误的原因。
但是我找不到上述3个问题的答案:(
回答您的问题:
不会删除隐式 jit 编译。jax.disable_jit()
如果这是真的,则这是一个错误,您应该在 JAX 问题跟踪器上报告它。从你的问题中不清楚是什么让你相信情况确实如此。
- 为什么 jax.nn.softmax 默认调用 _softmax_deprecated?
因为
_softmax_deprecated
是旧的默认算法,有一天它会被弃用,但弃用尚未发生。有关详细信息,请参阅https://github.com/google/jax/pull/15677。要使用较新的算法,您可以设置 jax_softmax_custom_jvp=True
配置。
- 我在 _softmax_deprecated 内的减法中遇到 NaN:
如果需要,我将附上代码供您参考:unnormalized = jnp.exp(x - lax.stop_gradient(x_max))
您没有包含足够的代码来重现您的问题(下次,尝试添加一个最小可重现示例,以允许其他人无需猜测即可回答您的问题)。但值得设置
jax_softmax_custom_jvp=True
看看是否可以解决您的问题。上面链接的拉取请求有详细信息。