在构建了一个继承自其他类的类的实例之后,我想更改 `__init__` 方法中 `seed` 参数的默认值。`Environment` 类定义在 `lab2d/dmlab2d/__init__.py` 中。当我像下面这样打印我的类的这个属性
print(f"dmlab2d {env._env.__class__.__bases__[0].__bases__[0].__dict__}")
时,输出是
dmlab2d {'__module__': 'dmlab2d', '__doc__': 'Environment class for DeepMind Lab2D.\n\n This environment extends the `dm_env` interface with additional methods.\n For details, see https://github.com/deepmind/dm_env\n ', '__init__': <function Environment.__init__ at 0x78d207cfd5a0>, 'reset': <function Environment.reset at 0x78d207cfd510>, '_read_action': <function Environment._read_action at 0x78d207cfd480>, 'step': <function Environment.step at 0x78d207cfd120>, 'observation': <function Environment.observation at 0x78d207cfd360>, 'observation_spec': <function Environment.observation_spec at 0x78d207cfd2d0>, '_make_observation_spec': <function Environment._make_observation_spec at 0x78d207cfd240>, '_make_action_spec': <function Environment._make_action_spec at 0x78d207cfd090>, 'action_spec': <function Environment.action_spec at 0x78d207cfc310>, 'events': <function Environment.events at 0x78d207cfc280>, 'list_property': <function Environment.list_property at 0x78d207cfc1f0>, 'write_property': <function Environment.write_property at 0x78d207cfc0d0>, 'read_property': <function Environment.read_property at 0x78d207cfc040>, '__abstractmethods__': frozenset(), '_abc_impl': <_abc._abc_data object at 0x78d207b66400>}
在这种类继承关系层层嵌套的情况下,有人能告诉我如何更改 `Environment` 类中的 `seed` 值吗?
这是我的代码:
import dmlab2d
import gymnasium as gym
from matplotlib import pyplot as plt
from gymnasium import spaces
from meltingpot import substrate
from ml_collections import config_dict
import numpy as np
from ray.rllib.env import multi_agent_env
class MeltingPotEnv(multi_agent_env.MultiAgentEnv):
    """An adapter between the Melting Pot substrates and RLLib MultiAgentEnv.

    The wrapped dmlab2d environment reports per-player specs, rewards and
    observations as sequences ordered by player index, while RLlib expects
    dictionaries keyed by agent id. This wrapper converts between the two.

    NOTE(review): `MAX_CYCLES`, `PLAYER_STR_FORMAT`, `spec_to_space`,
    `timestep_to_observations` and `remove_world_observations_from_space` are
    not defined in this snippet. `MAX_CYCLES` must exist before this class is
    defined (it is evaluated as a default argument). The others must exist
    before the methods run. They are presumably the helpers from Melting Pot's
    RLlib example utilities -- confirm.
    """

    # Metadata is required by the gym `Env` class that we are extending, to show
    # which modes the `render` method supports. Gymnasium reads `render_modes`;
    # the legacy `render.modes` key is kept for older consumers.
    metadata = {'render.modes': ['rgb_array'], 'render_modes': ['rgb_array']}

    def __init__(self, env: dmlab2d.Environment, max_cycles: int = MAX_CYCLES):
        """Initializes the instance.

        Args:
          env: dmlab2d environment to wrap. Will be closed when this wrapper
            closes (see `close`).
          max_cycles: number of `step` calls after which every agent (and the
            episode as a whole) is reported as truncated.
        """
        self._env = env
        self._num_players = len(self._env.observation_spec())
        self._ordered_agent_ids = [
            PLAYER_STR_FORMAT.format(index=index)
            for index in range(self._num_players)
        ]
        # RLLib requires environments to have the following member variables:
        # observation_space, action_space, and _agent_ids
        self._agent_ids = set(self._ordered_agent_ids)
        # RLLib expects a dictionary of agent_id to observation or action,
        # Melting Pot uses a tuple, so we convert
        self.observation_space = self._convert_spaces_tuple_to_dict(
            spec_to_space(self._env.observation_spec()),
            remove_world_observations=True)
        self.action_space = self._convert_spaces_tuple_to_dict(
            spec_to_space(self._env.action_spec()))
        self.max_cycles = max_cycles
        self.num_cycles = 0
        super().__init__()

    def reset(self, *args, **kwargs):
        """See base class.

        Any arguments (e.g. `seed`, `options`) are accepted for API
        compatibility but ignored. The wrapped environment's `reset` is called
        without arguments, so seeding has to happen when the underlying
        dmlab2d environment is built.

        Returns:
          A tuple `(observations, infos)`; `infos` is always empty.
        """
        timestep = self._env.reset()
        self.num_cycles = 0
        return timestep_to_observations(timestep), {}

    def step(self, action_dict):
        """See base class.

        Args:
          action_dict: mapping from agent id to that agent's action. Every
            agent id of this environment must be present.

        Returns:
          A tuple `(observations, rewards, terminateds, truncateds, infos)`.
          Both `terminateds` and `truncateds` carry the special `'__all__'`
          key that RLlib uses to decide whether the whole episode is over.
        """
        # The wrapped env takes actions as a sequence ordered by player index.
        actions = [action_dict[agent_id] for agent_id in self._ordered_agent_ids]
        timestep = self._env.step(actions)
        rewards = {
            agent_id: timestep.reward[index]
            for index, agent_id in enumerate(self._ordered_agent_ids)
        }
        self.num_cycles += 1
        termination = timestep.last()
        done = {'__all__': termination}
        truncation = self.num_cycles >= self.max_cycles
        truncations = {agent_id: truncation for agent_id in self._ordered_agent_ids}
        # RLlib reads the episode-level flag from '__all__'. Without it the
        # max_cycles time limit is never reported as the end of the episode.
        truncations['__all__'] = truncation
        info = {}
        observations = timestep_to_observations(timestep)
        return observations, rewards, done, truncations, info

    def close(self):
        """Closes the wrapped dmlab2d environment."""
        self._env.close()

    def get_dmlab2d_env(self):
        """Returns the underlying DM Lab2D environment."""
        return self._env

    def render(self) -> np.ndarray:
        """Render the environment.

        This allows you to set `record_env` in your training config, to record
        videos of gameplay.

        Returns:
          np.ndarray: This returns a numpy.ndarray with shape (x, y, 3),
            representing RGB values for an x-by-y pixel image, suitable for
            turning into a video.
        """
        observation = self._env.observation()
        # The world view is read from the first player's observation.
        world_rgb = observation[0]['WORLD.RGB']
        # RGB mode is used for recording videos
        return world_rgb

    def _convert_spaces_tuple_to_dict(
            self,
            input_tuple: spaces.Tuple,
            remove_world_observations: bool = False) -> spaces.Dict:
        """Returns spaces tuple converted to a dictionary.

        Args:
          input_tuple: tuple to convert, indexed by player index.
          remove_world_observations: If True will remove non-player
            observations.

        Returns:
          A `spaces.Dict` keyed by agent id, in player-index order.
        """
        return spaces.Dict({
            agent_id: (remove_world_observations_from_space(input_tuple[i])
                       if remove_world_observations else input_tuple[i])
            for i, agent_id in enumerate(self._ordered_agent_ids)
        })
# Build the substrate named in the config and expose it through the RLlib
# multi-agent adapter.
env = MeltingPotEnv(
    substrate.build(env_config['substrate'], roles=env_config['roles']))
`seed` 是 `Environment` 构造函数的参数,用于下面这条赋值语句:
self._rng = np.random.RandomState(seed=seed)
由于 `seed` 是在构造函数中用来创建 `self._rng` 的,实例构建完成之后再修改 `__init__` 的默认值不会产生任何效果。如果要设置 `seed`,需要在构造 `Environment` 时通过构造函数参数传入:
Environment(env, observation_names, seed=your_seed)