stable_baselines3 由于 dummy_vec_env.py 中的错误,PPO 模型在训练期间崩溃

问题描述 投票:0回答:1

我正在尝试在 CartPole-v1 环境中训练 PPO 模型。

import gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy

env_id = "CartPole-v1"
#Making the environment
envs = make_vec_env(env_id, n_envs= 4)
envs = VecNormalize(envs)

#Training the model
model = PPO(policy="MlpPolicy", env=envs, verbose=1)
model.learn(1000)
model.save("CartPole-v1-model")
envs.save("CartPole-v1-env")

我收到此错误消息:

Error message

我只安装了带 cpu 的 pytorch,我怀疑这是错误的原因。但是,dummy_vec_env.py 和父 base_vec_env.py 的源代码根本不导入 pytorch,所以我不确定 pytorch 是原因。

我复制了代码并成功让它在 HuggingFace google colab notebook 中运行https://colab.research.google.com/github/huggingface/deep-rl-class/blob/master/notebooks/unit1/ unit1.ipynb 所以我非常困惑为什么它在我的本地机器上不起作用。

我检查了 dummy_vec_env.py 的调试器,我在变量 obs 中得到了一个元组。 debugger

任何帮助将不胜感激!

python pytorch openai-gym stable-baselines
1个回答
0
投票

问题是conda的stable_baselines3版本问题。我的 stable_baselines3 是版本 1.1.0.

安装更高版本的 stable_baselines3 using pip 解决了问题。 我用过

pip install stable-baselines3==2.0.0a5

注意:我安装了 2.0.0a5 以跟随 HuggingFace Google Collab 页面,但有更高版本的 stable_baselines3 很可能也可以工作。

© www.soinside.com 2019 - 2024. All rights reserved.