Getting a very simple stablebaselines3 example to work


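I tried to model the simplest possible coin-flip game, where you have to predict whether the coin will come up heads. Unfortunately it won't run and gives me: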

Using cpu device
Traceback (most recent call last):
  File "/home/user/python/simplegame.py", line 40, in <module>
    model.learn(total_timesteps=10000)
  File "/home/user/python/mypython3.10/lib/python3.10/site-packages/stable_baselines3/ppo/ppo.py", line 315, in learn
    return super().learn(
  File "/home/user/python/mypython3.10/lib/python3.10/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 264, in learn
    total_timesteps, callback = self._setup_learn(
  File "/home/user/python/mypython3.10/lib/python3.10/site-packages/stable_baselines3/common/base_class.py", line 423, in _setup_learn
    self._last_obs = self.env.reset()  # type: ignore[assignment]
  File "/home/user/python/mypython3.10/lib/python3.10/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py", line 77, in reset
    obs, self.reset_infos[env_idx] = self.envs[env_idx].reset(seed=self._seeds[env_idx], **maybe_options)
TypeError: CoinFlipEnv.reset() got an unexpected keyword argument 'seed'

Here is the code:

import gymnasium as gym
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv

class CoinFlipEnv(gym.Env):
    def __init__(self, heads_probability=0.8):
        super(CoinFlipEnv, self).__init__()
        self.action_space = gym.spaces.Discrete(2)  # 0 for heads, 1 for tails
        self.observation_space = gym.spaces.Discrete(2)  # 0 for heads, 1 for tails
        self.heads_probability = heads_probability
        self.flip_result = None

    def reset(self):
        # Reset the environment
        self.flip_result = None
        return self._get_observation()

    def step(self, action):
        # Perform the action (0 for heads, 1 for tails)
        self.flip_result = int(np.random.rand() < self.heads_probability)

        # Compute the reward (1 for correct prediction, -1 for incorrect)
        reward = 1 if self.flip_result == action else -1

        # Return the observation, reward, done, and info
        return self._get_observation(), reward, True, {}

    def _get_observation(self):
        # Return the current coin flip result
        return self.flip_result

# Create the environment with heads probability of 0.8
env = DummyVecEnv([lambda: CoinFlipEnv(heads_probability=0.8)])

# Create the PPO model
model = PPO("MlpPolicy", env, verbose=1)

# Train the model
model.learn(total_timesteps=10000)

# Save the model
model.save("coin_flip_model")

# Evaluate the model
obs = env.reset()
for _ in range(10):
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    print(f"Action: {action}, Observation: {obs}, Reward: {rewards}")

What am I doing wrong?

I'm using version 2.2.1.

Tags: python, stable-baselines

1 Answer

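The problem is that your CoinFlipEnv class does not conform to the gym.Env interface, specifically its reset function. According to the documentation at https://gymnasium.farama.org/api/env/#gymnasium.Env.reset, reset takes a seed argument, so your function: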

def reset(self)

must accept one as well (the documentation also says it takes an options keyword argument, so we include that too):

def reset(self, seed=None, options=None)

The missing seed parameter is exactly what the last line of the stack trace points to:

obs, self.reset_infos[env_idx] = self.envs[env_idx].reset(seed=self._seeds[env_idx], **maybe_options)

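The object self.envs[env_idx] is an instance of your CoinFlipEnv class. DummyVecEnv calls its reset function with the keyword arguments seed=self._seeds[env_idx] and **maybe_options. Because it passes a named argument seed to a function that does not define a seed parameter, Python raises a TypeError.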
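The mechanism can be reproduced in plain Python without gymnasium or Stable-Baselines3. In the sketch below, `OriginalStyleEnv`, `FixedStyleEnv` and `reset_like_dummy_vec_env` are hypothetical names for illustration only. The helper is a rough imitation of the failing traceback line. The assumption that options are only forwarded when they are set is inferred from the name `maybe_options`, not taken from the library source.

```python
class OriginalStyleEnv:
    """Hypothetical stand-in whose reset() takes no arguments, like the question's."""

    def reset(self):
        return None


class FixedStyleEnv:
    """Hypothetical stand-in whose reset() accepts seed and options."""

    def reset(self, seed=None, options=None):
        return None, {}


def reset_like_dummy_vec_env(env, seed, options=None):
    # Rough imitation of the failing line: seed is always passed by keyword,
    # options (assumed) only when they are set
    maybe_options = {"options": options} if options else {}
    return env.reset(seed=seed, **maybe_options)


try:
    reset_like_dummy_vec_env(OriginalStyleEnv(), seed=0)
except TypeError as exc:
    # The message ends with: got an unexpected keyword argument 'seed'
    print(exc)

print(reset_like_dummy_vec_env(FixedStyleEnv(), seed=0))  # (None, {})
```

The call fails before any of the environment's own code runs: Python rejects the keyword argument at the call site, which is why the error appears as soon as model.learn() triggers the first reset.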
