获取帧作为 CartPole 环境的观察

问题描述 投票:0回答:1

在 Python 中,我使用

stablebaselines3
gymnasium
来实现自定义 DQN。使用 atari 游戏我测试了代理并正常工作,现在我还需要在像
CartPole
这样的环境中测试它 问题是这种环境不会返回帧作为观察,而是仅返回一个向量。 因此,我需要一种方法来返回 CartPole 帧作为观察,并应用与 Atari 游戏相同的预处理内容(例如将游戏的 4 帧堆叠在一起)

我在互联网上搜索了如何做到这一点,经过一番尝试后我想出了这个代码,但是我遇到了一些问题。

这是代码:

from stable_baselines3.common.env_util import make_atari_env, make_vec_env
from stable_baselines3.common.vec_env import VecFrameStack, VecTransposeImage
import gymnasium as gym
from gymnasium import spaces
from gymnasium.envs.classic_control import CartPoleEnv
import numpy as np
import cv2


class CartPoleImageWrapper(gym.Wrapper):
    metadata = {'render.modes': ['rgb_array']}

    def __init__(self, env):
        super(CartPoleImageWrapper, self).__init__(env)
        self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

    def _get_image_observation(self):
        # Render the CartPole environment
        cartpole_image = self.render()

        # Resize the image to 84x84 pixels
        resized_image = cv2.resize(cartpole_image, (84, 84))
        # make it grayscale
        resized_image = cv2.cvtColor(resized_image, cv2.COLOR_RGB2GRAY)
        resized_image = np.expand_dims(resized_image, axis=-1)
        return resized_image

    def reset(self):
        self.env.reset()
        return self._get_image_observation()

    def step(self, action):
        observation, reward, terminated, info = self.env.step(action)
        return self._get_image_observation(), reward, terminated, info


env = CartPoleImageWrapper(CartPoleEnv(render_mode='rgb_array'))
vec_env = make_vec_env(lambda: env, n_envs=1)
vec_env = VecTransposeImage(vec_env)
vec_env = VecFrameStack(vec_env, n_stack=4)
obs = vec_env.reset()
print(f"Observation space: {obs.shape}")
#exit()
    
vec_env.close()

当我打电话

env.reset()
时,错误是这样的:

Traceback (most recent call last):
    File "/data/g.carfi/rl/tmp.py", line 41, in <module>
        obs = vec_env.reset()
    File "/data/g/.virtualenvs/rl_new/lib/python3.8/site-packages/stable_baselines3/common/vec_env/vec_frame_stack.py", line 41, in reset
        observation = self.venv.reset()
    File "/data/g/.virtualenvs/rl_new/lib/python3.8/site-packages/stable_baselines3/common/vec_env/vec_transpose.py", line 113, in reset
        observations = self.venv.reset()
    File "/data/g/.virtualenvs/rl_new/lib/python3.8/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py", line 77, in reset
        obs, self.reset_infos[env_idx] = self.envs[env_idx].reset(seed=self._seeds[env_idx], **maybe_options)
    File "/data/g/.virtualenvs/rl_new/lib/python3.8/site-packages/stable_baselines3/common/monitor.py", line 83, in reset
        return self.env.reset(**kwargs)
    TypeError: reset() got an unexpected keyword argument 'seed'

我该如何解决这个问题?

python reinforcement-learning openai-gym atari-2600
1个回答
0
投票

您遇到的问题是由于 CartPoleEnv 类的 reset() 方法不接受种子参数,但它似乎是由 VecEnv 内部传递的。

要解决此问题,您可以修改 CartPoleImageWrapper 类中的 Reset() 方法来处理此差异。在调用包装环境的 Reset() 方法时,您可以简单地忽略种子参数。具体方法如下:

class CartPoleImageWrapper(gym.Wrapper):
    metadata = {'render.modes': ['rgb_array']}

def __init__(self, env):
    super(CartPoleImageWrapper, self).__init__(env)
    self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

def _get_image_observation(self):
    # Render the CartPole environment
    cartpole_image = self.render()

    # Resize the image to 84x84 pixels
    resized_image = cv2.resize(cartpole_image, (84, 84))
    # make it grayscale
    resized_image = cv2.cvtColor(resized_image, cv2.COLOR_RGB2GRAY)
    resized_image = np.expand_dims(resized_image, axis=-1)
    return resized_image

def reset(self, **kwargs):
    self.env.reset(**kwargs)  # Ignore the 'seed' argument
    return self._get_image_observation()

def step(self, action):
    observation, reward, terminated, info = self.env.step(action)
    return self._get_image_observation(), reward, terminated, info

通过此修改,您应该能够将 CartPoleImageWrapper 与 VecFrameStack 一起使用,而不会遇到与意外种子参数相关的 TypeError。

© www.soinside.com 2019 - 2024. All rights reserved.