无法从“jax”导入名称“linear_util”

问题描述 投票:0回答:1

我正在尝试重现S5模型的实验,https://github.com/lindermanlab/S5,但是在解决环境时遇到了一些问题。当我运行 shell 脚本

./run_lra_cifar.sh
时,出现以下错误

Traceback (most recent call last):
  File "/Path/S5/run_train.py", line 3, in <module>
    from s5.train import train
  File "/Path/S5/s5/train.py", line 7, in <module>
    from .train_helpers import create_train_state, reduce_lr_on_plateau,\
  File "/Path/train_helpers.py", line 6, in <module>
    from flax.training import train_state
  File "/Path/miniconda3/lib/python3.12/site-packages/flax/__init__.py", line 19, in <module>
    from . import core
  File "/Path/miniconda3/lib/python3.12/site-packages/flax/core/__init__.py", line 15, in <module>
    from .axes_scan import broadcast
  File "/Path/miniconda3/lib/python3.12/site-packages/flax/core/axes_scan.py", line 22, in <module>
    from jax import linear_util as lu
ImportError: cannot import name 'linear_util' from 'jax' (/Path/miniconda3/lib/python3.12/site-packages/jax/__init__.py)

我在 RTX4090 上运行它,我的 CUDA 版本是 11.8。我的jax版本是0.4.25,jaxlib版本是0.4.25+cuda11.cudnn86

我首先尝试使用作者的

安装依赖项
pip install -r requirements_gpu.txt

但是,这对我来说似乎不起作用,因为我什至不能

import jax
。所以我按照https://jax.readthedocs.io/en/latest/installation.html上的说明安装了jax 通过输入

pip install --upgrade pip
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

到目前为止我已经尝试过:

  1. 使用较旧的 GPU(3060 和 2070)
  2. 将 python 降级到 3.9

有人知道可能出了什么问题吗?任何帮助表示赞赏

machine-learning deep-learning gpu jax state-space
1个回答
0
投票

jax.linear_util
在 JAX v0.4.16 中已被 弃用,并且 在 JAX v0.4.24 中被删除

要解决您的问题,您需要安装仍具有

jax.linear_util

 的旧版本 JAX,或者更新到与最新 JAX 版本兼容的新版本 
flax

© www.soinside.com 2019 - 2024. All rights reserved.