AttributeError:模块“flax.traverse_util”没有属性“unfreeze”

问题描述 投票:0回答:1

我正在尝试运行用 jax 编写的模型,https://github.com/lindermanlab/S5。但是,我遇到了一些错误,上面写着

   Traceback (most recent call last):
  File "/Path/run_train.py", line 101, in <module>
    train(parser.parse_args())
  File "/Path/train.py", line 144, in train
    state = create_train_state(model_cls,
  File "/Path/train_helpers.py", line 135, in create_train_state
    params = variables["params"].unfreeze()
AttributeError: 'dict' object has no attribute 'unfreeze'

我尝试通过

复制此错误
import jax
import jax.numpy as jnp
import flax
from flax import linen as nn

model = nn.Dense(features=3)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 2)))
params_unfrozen = flax.traverse_util.unfreeze(params)

错误显示:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: module 'flax.traverse_util' has no attribute 'unfreeze'

我正在使用:

flax 0.7.4
jax 0.4.13
jaxlib 0.4.13+cuda12.cudnn89

我认为这是与亚麻版本有关的问题,但有人知道到底发生了什么吗?任何帮助表示赞赏。如果您需要任何进一步的信息,请告诉我

attributeerror jax flax
1个回答
0
投票

unfreeze
是 Flax
FrozenDict
类的方法:(参见
FrozenDict.unfreeze
)。您似乎已通过了 Python
dict
,而预期为
FrozenDict

要解决此问题,您应该确保

variables['params']
FrozenDict
,而不是
dict

关于您尝试复制中的错误:

flax.traverse_util
未定义
unfreeze
函数,但这似乎与原始问题无关。

© www.soinside.com 2019 - 2024. All rights reserved.