2023年如何使用FLAX LSTM

问题描述 投票:0回答:1

我想知道这里是否有人知道如何让 FLAX LSTM 层在 2023 年工作。我已经尝试了实际 Flax 文档中的一些代码片段,例如:

https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.scan.html

并且,那里提供的第一个示例,

import flax.linen as nn
import jax
import jax.numpy as jnp

class LSTM(nn.Module):
  features: int

  @nn.compact
  def __call__(self, x):
    ScanLSTM = nn.scan(
      nn.LSTMCell, variable_broadcast="params",
      split_rngs={"params": False}, in_axes=1, out_axes=1)

    lstm = ScanLSTM(self.features)
    input_shape =  x[:, 0].shape
    carry = lstm.initialize_carry(jax.random.key(0), input_shape)
    carry, x = lstm(carry, x)
    return x

x = jnp.ones((4, 12, 7))
module = LSTM(features=32)
y, variables = module.init_with_output(jax.random.key(0), x)

抛出错误。我已经查找了其他示例,但似乎他们在 2023 年的某个时候更改了 API,所以我在网上找到的内容不再起作用。

简而言之,我正在寻找一个关于如何将时间序列传递到 FLAX 中的 LSTM 的简单示例。

谢谢您的帮助。

python lstm jax flax
1个回答
0
投票

您提供的代码片段可以在最新版本的 flax(版本 0.7.4)上正确运行。如果您使用旧版本的亚麻,则应将

jax.random.key
更改为
jax.random.PRNGKey
。有关此 JAX PRNG 密钥更改的一些信息,请参阅JEP 9263:类型化密钥和可插入 PRNG

© www.soinside.com 2019 - 2024. All rights reserved.