Stopping the Stable-Baselines learn method when execution terminates


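I'm using the "learn" method of stable-baselines3 to train a model. The rewards in my environment are negative, and I want the agent to be motivated to finish as quickly as possible. The idea is for the agent to search for a particular state: at every step it receives a negative reward for performing one more iteration, and this continues until the successful state is found (so the more iterations it takes, the larger the accumulated penalty). My problem is that even when my _is_done() method returns True (which triggers an environment reset), the model does not stop training. I'm therefore worried that, rather than trying to find the successful state as early as possible, the agent will do the exact opposite and try to reset as close as possible to the last training step. That is why I want to break out of the learning loop when the successful state is found, so that the earlier the state is found, the higher the cumulative negative reward (that is, the lower its absolute value).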
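To make the reward arithmetic concrete, here is a minimal sketch of the per-episode return under a constant step penalty. The penalty of -1 per step and the discount factor gamma are assumptions chosen for illustration, not values from the post. It only shows that, when the return is summed up to the terminating step, an episode that reaches the successful state sooner has the higher (less negative) return.

```python
def episode_return(num_steps: int, step_reward: float = -1.0, gamma: float = 1.0) -> float:
    """Discounted sum of a constant per-step reward over an episode
    that terminates after `num_steps` steps."""
    return sum(step_reward * gamma**t for t in range(num_steps))


# Hypothetical episodes: one finds the successful state after 5 steps,
# the other after 50 steps.
fast = episode_return(5)
slow = episode_return(50)

# The same comparison with discounting (gamma is an assumed value).
fast_discounted = episode_return(5, gamma=0.99)
slow_discounted = episode_return(50, gamma=0.99)

print(fast, slow)
print(fast_discounted > slow_discounted)
```

Under these assumptions the shorter episode always scores better, whether or not the overall training run keeps going after the episode ends.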

How can I make learning stop in this case?

My implementation looks like this:

Environment:

def step(self, action, episode = -1, step = -1, reward_file_path = None, verbose = False):
        """
        Equivalent to Iteration step. Method used to perform one iteration step by taking the action, observing the resulting state and computing the reward
        """
        print(f"\n\t\t************************ Current step #{self.current_step} ***************************************************")
        if episode != self.current_episode and episode != -1:
            self.current_episode = episode
        if step != self.current_step and step != -1:
            self.current_step = step
        truncated = False
        try:
            ...

            # Calculate reward and done flag based on current state
            self._calculate_reward()

            # Store the reward
            self.store_reward(reward_file_path)
            # Returns True if the successful state is found
            self.terminated = self._is_done()

            self.obs_end = self._get_observation()

            self.current_step += 1
        return self._get_observation(), self.reward, self.terminated, truncated, {}

In my agent class:

self.model.learn(total_timesteps=steps_to_train, callback=None, log_interval=1, tb_log_name=tb_log_name, reset_num_timesteps=reset_num_timesteps, progress_bar=progress_bar) # I've tried with both reset_num_timesteps True and False

I also tried using a callback, but training does not stop even when it returns False:

import numpy as np
from stable_baselines3.common.callbacks import BaseCallback

class StopOnSuccessCallback(BaseCallback):
    def __init__(self, verbose=0):
        super(StopOnSuccessCallback, self).__init__(verbose)

    def _on_step(self):
        print("ON STEP")
        # Access the environment from the model
        terminated = self.model.env.envs[0].terminated  # Assumes that the first environment has the termination attribute
        if np.any(terminated):
            self.model.logger.info("Termination signal received. Stopping training.")
            print("Termination signal received. Stopping training.")
            return False  # Stop training
        return True  # Continue training

More code...

  callback = StopOnSuccessCallback()  
  self.model.learn(total_timesteps=steps_to_train, callback=callback, log_interval=log_interval,
                                tb_log_name=tb_log_name, reset_num_timesteps=reset_num_timesteps, progress_bar=progress_bar)
python-3.x reinforcement-learning openai-gym stable-baselines
1 Answer

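For anyone interested: I ended up changing my callback so that, instead of reading the terminated attribute, it calls a public function that internally calls _is_done():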

class StopOnSuccessCallback(BaseCallback):
    def __init__(self, verbose=0):
        super(StopOnSuccessCallback, self).__init__(verbose)

    def _on_step(self):
        env = self.model.env.envs[0]
        # Access the environment from the model and call its public is_done() method
        terminated = env.is_done()
        if terminated:
            self.model.logger.info("Termination signal received. Stopping training.")
        return not terminated  # False stops training, True continues
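An alternative that avoids reaching into the wrapped environment is to base the decision on the per-step outputs of the vectorized environment. The sketch below is not the poster's method. It is a plain-Python helper (the name any_env_terminated is hypothetical) that assumes the stable-baselines3 convention of flagging time-limit truncations under the "TimeLimit.truncated" key of each info dict. One possible reason for preferring this, offered here as an assumption rather than something confirmed in the post, is that SB3's vectorized environments reset a sub-environment as soon as its episode ends, so attributes read from the environment afterwards may already reflect the fresh episode.

```python
from typing import Any, Mapping, Sequence


def any_env_terminated(dones: Sequence[bool], infos: Sequence[Mapping[str, Any]]) -> bool:
    """Return True if at least one sub-environment ended its episode by
    termination (e.g. the successful state) rather than by truncation."""
    for done, info in zip(dones, infos):
        # Assumed convention: truncated episodes carry this flag in `info`.
        truncated = bool(info.get("TimeLimit.truncated", False))
        if done and not truncated:
            return True
    return False


# Hypothetical step outputs for a single-environment setup.
print(any_env_terminated([True], [{}]))                              # terminated
print(any_env_terminated([True], [{"TimeLimit.truncated": True}]))   # truncated only
print(any_env_terminated([False], [{}]))                             # still running
```

Inside _on_step, such a helper could be called as return not any_env_terminated(self.locals["dones"], self.locals["infos"]), assuming the running algorithm exposes those two entries in self.locals.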