我正在尝试重现S5模型的实验,https://github.com/lindermanlab/S5,但是在解决环境时遇到了一些问题。当我运行 shell 脚本
./run_lra_cifar.sh
时,出现以下错误
Traceback (most recent call last):
File "/Path/S5/run_train.py", line 3, in <module>
from s5.train import train
File "/Path/S5/s5/train.py", line 7, in <module>
from .train_helpers import create_train_state, reduce_lr_on_plateau,\
File "/Path/train_helpers.py", line 6, in <module>
from flax.training import train_state
File "/Path/miniconda3/lib/python3.12/site-packages/flax/__init__.py", line 19, in <module>
from . import core
File "/Path/miniconda3/lib/python3.12/site-packages/flax/core/__init__.py", line 15, in <module>
from .axes_scan import broadcast
File "/Path/miniconda3/lib/python3.12/site-packages/flax/core/axes_scan.py", line 22, in <module>
from jax import linear_util as lu
ImportError: cannot import name 'linear_util' from 'jax' (/Path/miniconda3/lib/python3.12/site-packages/jax/__init__.py)
我在 RTX4090 上运行它,我的 CUDA 版本是 11.8。我的jax版本是0.4.25,jaxlib版本是0.4.25+cuda11.cudnn86
我首先尝试使用作者的
安装依赖项pip install -r requirements_gpu.txt
但是,这对我来说似乎不起作用,因为我什至不能
import jax
。所以我按照https://jax.readthedocs.io/en/latest/installation.html上的说明安装了jax
通过输入
pip install --upgrade pip
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
到目前为止我已经尝试过:
有人知道可能出了什么问题吗?任何帮助表示赞赏