我正在尝试在 CartPole-v1 环境中训练 PPO 模型。
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
env_id = "CartPole-v1"
#Making the environment
envs = make_vec_env(env_id, n_envs= 4)
envs = VecNormalize(envs)
#Training the model
model = PPO(policy="MlpPolicy", env=envs, verbose=1)
model.learn(1000)
model.save("CartPole-v1-model")
envs.save("CartPole-v1-env")
我收到此错误消息:
我只安装了带 cpu 的 pytorch,我怀疑这是错误的原因。但是,dummy_vec_env.py 和父 base_vec_env.py 的源代码根本不导入 pytorch,所以我不确定 pytorch 是原因。
我复制了代码并成功让它在 HuggingFace google colab notebook 中运行https://colab.research.google.com/github/huggingface/deep-rl-class/blob/master/notebooks/unit1/ unit1.ipynb 所以我非常困惑为什么它在我的本地机器上不起作用。
我检查了 dummy_vec_env.py 的调试器,我在变量 obs 中得到了一个元组。
任何帮助将不胜感激!
问题是conda的stable_baselines3版本问题。我的 stable_baselines3 是版本 1.1.0.
安装更高版本的 stable_baselines3 using pip 解决了问题。 我用过
pip install stable-baselines3==2.0.0a5
注意:我安装了 2.0.0a5 以跟随 HuggingFace Google Collab 页面,但有更高版本的 stable_baselines3 很可能也可以工作。