我正在尝试运行用 jax 编写的模型,https://github.com/lindermanlab/S5。但是,我遇到了一些错误,上面写着
Traceback (most recent call last):
File "/Path/run_train.py", line 101, in <module>
train(parser.parse_args())
File "/Path/train.py", line 144, in train
state = create_train_state(model_cls,
File "/Path/train_helpers.py", line 135, in create_train_state
params = variables["params"].unfreeze()
AttributeError: 'dict' object has no attribute 'unfreeze'
我尝试通过
复制此错误import jax
import jax.numpy as jnp
import flax
from flax import linen as nn
model = nn.Dense(features=3)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 2)))
params_unfrozen = flax.traverse_util.unfreeze(params)
错误显示:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AttributeError: module 'flax.traverse_util' has no attribute 'unfreeze'
我正在使用:
flax 0.7.4
jax 0.4.13
jaxlib 0.4.13+cuda12.cudnn89
我认为这是与亚麻版本有关的问题,但有人知道到底发生了什么吗?任何帮助表示赞赏。如果您需要任何进一步的信息,请告诉我
unfreeze
是 Flax FrozenDict
类的方法:(参见 FrozenDict.unfreeze
)。您似乎已通过了 Python dict
,而预期为 FrozenDict
。
要解决此问题,您应该确保
variables['params']
是 FrozenDict
,而不是 dict
。
关于您尝试复制中的错误:
flax.traverse_util
未定义unfreeze
函数,但这似乎与原始问题无关。