我想知道这里是否有人知道如何让 FLAX LSTM 层在 2023 年工作。我已经尝试了实际 Flax 文档中的一些代码片段,例如:
https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.scan.html
并且,那里提供的第一个示例,
import flax.linen as nn
import jax
import jax.numpy as jnp
class LSTM(nn.Module):
features: int
@nn.compact
def __call__(self, x):
ScanLSTM = nn.scan(
nn.LSTMCell, variable_broadcast="params",
split_rngs={"params": False}, in_axes=1, out_axes=1)
lstm = ScanLSTM(self.features)
input_shape = x[:, 0].shape
carry = lstm.initialize_carry(jax.random.key(0), input_shape)
carry, x = lstm(carry, x)
return x
x = jnp.ones((4, 12, 7))
module = LSTM(features=32)
y, variables = module.init_with_output(jax.random.key(0), x)
抛出错误。我已经查找了其他示例,但似乎他们在 2023 年的某个时候更改了 API,所以我在网上找到的内容不再起作用。
简而言之,我正在寻找一个关于如何将时间序列传递到 FLAX 中的 LSTM 的简单示例。
谢谢您的帮助。
您提供的代码片段可以在最新版本的 flax(版本 0.7.4)上正确运行。如果您使用旧版本的亚麻,则应将
jax.random.key
更改为 jax.random.PRNGKey
。有关此 JAX PRNG 密钥更改的一些信息,请参阅JEP 9263:类型化密钥和可插入 PRNG。