我正在尝试使用 JAX 运行 Colab 笔记本来生成图像,但遇到了以下错误:
WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-7-73b0723cc3af> in <cell line: 23>()
21 import jax.numpy as jnp
22 import jax.scipy as jsp
---> 23 import jaxtorch
24 from jaxtorch import PRNG, Context, Module, nn, init
25 from tqdm import tqdm
3 frames
/content/./jax-guided-diffusion/jaxtorch/monkeypatches.py in register(**kwargs)
16 print(f'Not monkeypatching DeviceArray and Tracer with `{attr}`, because that method is already implemented.', file=sys.stderr)
17 continue
---> 18 setattr(jaxlib.xla_extension.DeviceArrayBase, attr, fun)
19 setattr(jax.interpreters.xla.DeviceArray, attr, fun)
20 setattr(jax.core.Tracer, attr, fun)
AttributeError: module 'jaxlib.xla_extension' has no attribute 'DeviceArrayBase'
我尝试通过使用不同的 JAX 版本和 Colab 提供的每个 GPU 来解决这个问题,但找不到解决方案。我真的很乐意在这方面提供任何帮助!
笔记本链接---> 点击
DeviceArray
及相关类型已在 JAX v0.4.1 中弃用并删除(请参阅 Changelog)。您使用的 jaxtorch
版本似乎与更新的 JAX 版本不兼容。如果没有可用的新版本 jaxtorch
,我建议尝试将其与 JAX 版本 0.3.25 或更早版本一起使用。