在 Python 中,我使用
stablebaselines3
和 gymnasium
来实现自定义 DQN。使用 atari 游戏我测试了代理并正常工作,现在我还需要在像 CartPole
这样的环境中测试它
问题是这种环境不会返回帧作为观察,而是仅返回一个向量。
因此,我需要一种方法来返回 CartPole 帧作为观察,并应用与 Atari 游戏相同的预处理内容(例如将游戏的 4 帧堆叠在一起)
我在互联网上搜索了如何做到这一点,经过一番尝试后我想出了这个代码,但是我遇到了一些问题。
这是代码:
from stable_baselines3.common.env_util import make_atari_env, make_vec_env
from stable_baselines3.common.vec_env import VecFrameStack, VecTransposeImage
import gymnasium as gym
from gymnasium import spaces
from gymnasium.envs.classic_control import CartPoleEnv
import numpy as np
import cv2
class CartPoleImageWrapper(gym.Wrapper):
metadata = {'render.modes': ['rgb_array']}
def __init__(self, env):
super(CartPoleImageWrapper, self).__init__(env)
self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)
def _get_image_observation(self):
# Render the CartPole environment
cartpole_image = self.render()
# Resize the image to 84x84 pixels
resized_image = cv2.resize(cartpole_image, (84, 84))
# make it grayscale
resized_image = cv2.cvtColor(resized_image, cv2.COLOR_RGB2GRAY)
resized_image = np.expand_dims(resized_image, axis=-1)
return resized_image
def reset(self):
self.env.reset()
return self._get_image_observation()
def step(self, action):
observation, reward, terminated, info = self.env.step(action)
return self._get_image_observation(), reward, terminated, info
env = CartPoleImageWrapper(CartPoleEnv(render_mode='rgb_array'))
vec_env = make_vec_env(lambda: env, n_envs=1)
vec_env = VecTransposeImage(vec_env)
vec_env = VecFrameStack(vec_env, n_stack=4)
obs = vec_env.reset()
print(f"Observation space: {obs.shape}")
#exit()
vec_env.close()
当我打电话
env.reset()
时,错误是这样的:
Traceback (most recent call last):
File "/data/g.carfi/rl/tmp.py", line 41, in <module>
obs = vec_env.reset()
File "/data/g/.virtualenvs/rl_new/lib/python3.8/site-packages/stable_baselines3/common/vec_env/vec_frame_stack.py", line 41, in reset
observation = self.venv.reset()
File "/data/g/.virtualenvs/rl_new/lib/python3.8/site-packages/stable_baselines3/common/vec_env/vec_transpose.py", line 113, in reset
observations = self.venv.reset()
File "/data/g/.virtualenvs/rl_new/lib/python3.8/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py", line 77, in reset
obs, self.reset_infos[env_idx] = self.envs[env_idx].reset(seed=self._seeds[env_idx], **maybe_options)
File "/data/g/.virtualenvs/rl_new/lib/python3.8/site-packages/stable_baselines3/common/monitor.py", line 83, in reset
return self.env.reset(**kwargs)
TypeError: reset() got an unexpected keyword argument 'seed'
我该如何解决这个问题?
您遇到的问题是由于 CartPoleEnv 类的 reset() 方法不接受种子参数,但它似乎是由 VecEnv 内部传递的。
要解决此问题,您可以修改 CartPoleImageWrapper 类中的 Reset() 方法来处理此差异。在调用包装环境的 Reset() 方法时,您可以简单地忽略种子参数。具体方法如下:
class CartPoleImageWrapper(gym.Wrapper):
metadata = {'render.modes': ['rgb_array']}
def __init__(self, env):
super(CartPoleImageWrapper, self).__init__(env)
self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)
def _get_image_observation(self):
# Render the CartPole environment
cartpole_image = self.render()
# Resize the image to 84x84 pixels
resized_image = cv2.resize(cartpole_image, (84, 84))
# make it grayscale
resized_image = cv2.cvtColor(resized_image, cv2.COLOR_RGB2GRAY)
resized_image = np.expand_dims(resized_image, axis=-1)
return resized_image
def reset(self, **kwargs):
self.env.reset(**kwargs) # Ignore the 'seed' argument
return self._get_image_observation()
def step(self, action):
observation, reward, terminated, info = self.env.step(action)
return self._get_image_observation(), reward, terminated, info
通过此修改,您应该能够将 CartPoleImageWrapper 与 VecFrameStack 一起使用,而不会遇到与意外种子参数相关的 TypeError。