Writing a callback that stops training when explained variance is within a range

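I am currently writing a reinforcement learning model with the stable_baselines3 library and gym_anytrading. I have written the code for an environment to train the model, and I train it for a large number of timesteps.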

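However, training often finishes while the explained variance is still at an undesirable level. So I want to write a callback that stops training once the explained variance is within a range, for example between 0.9 and 1.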
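For reference, A2C logs train/explained_variance by comparing the critic's value predictions with the empirical returns of each rollout, as 1 - Var[returns - values] / Var[returns], and SB3's helper returns NaN when the returns are constant. The sketch below is a stdlib-only reimplementation of that formula, assuming population variance as in np.var. Its explained_variance mirrors the name and argument order of stable_baselines3.common.utils.explained_variance but is not that function, and in_range is a hypothetical helper. Because the residual variance is never negative, the value can never exceed 1, so "between 0.9 and 1" is the same condition as ">= 0.9".

```python
import math
from statistics import pvariance


def explained_variance(y_pred, y_true):
    """1 - Var[y_true - y_pred] / Var[y_true], with population variance (like np.var).

    Returns NaN when y_true is constant, mirroring SB3's helper.
    """
    var_y = pvariance(y_true)
    if var_y == 0:
        return math.nan
    residuals = [t - p for t, p in zip(y_true, y_pred)]
    return 1 - pvariance(residuals) / var_y


def in_range(value, low=0.9, high=1.0):
    """True if value lies in [low, high]; NaN fails every comparison, so it is never in range."""
    return low <= value <= high


returns = [1.0, 2.0, 3.0, 4.0]
print(explained_variance([1.0, 2.0, 3.0, 4.0], returns))  # perfect value estimates: 1.0
print(explained_variance([2.5, 2.5, 2.5, 2.5], returns))  # only predicting the mean: 0.0
print(in_range(explained_variance([1.1, 1.9, 3.2, 3.9], returns)))  # True
```

Since NaN fails the comparison, a stopping rule built on this check simply keeps training when the returns of a rollout happen to be constant.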

Here is the environment and model setup I have so far:

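from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import DummyVecEnv

# MyCustomEnv and df are defined elsewhere in my code
env_maker = lambda: MyCustomEnv(df=df, frame_bound=(12, 30660), window_size=12)
env = DummyVecEnv([env_maker])

# SB3 >= 1.8 expects net_arch as a dict with separate actor (pi) and critic (vf) layers
model = A2C('MlpPolicy', env, verbose=1, policy_kwargs=dict(net_arch=dict(pi=[128, 256, 128], vf=[128, 256, 128])))

# Pass the custom callback (created in the next snippet) to the learn() method
model.learn(total_timesteps=1000000, callback=custom_stop_callback)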

Here is the callback that should stop training once the explained variance is above a certain value:

from stable_baselines3.common.callbacks import BaseCallback


class CustomStopCallback(BaseCallback):
    """Stop model.learn() once the explained variance reaches a threshold.

    Explained variance is at most 1, so ">= 0.9" means "between 0.9 and 1".
    """

    def __init__(self, explained_variance_threshold: float, value_loss_threshold: float, starting_step: int = 0):
        super().__init__()
        # No logger argument: BaseCallback already exposes the model's logger as self.logger
        self.explained_variance_threshold = explained_variance_threshold
        self.value_loss_threshold = value_loss_threshold
        self.starting_step = starting_step

    def _on_step(self) -> bool:
        if self.num_timesteps < self.starting_step:
            return True

        # A2C.train() records its metrics under the "train/" prefix. They stay in
        # name_to_value until the next logger.dump(), so the values from the latest
        # train() call are visible while the following rollout is collected.
        values = self.logger.name_to_value
        explained_variance = values.get("train/explained_variance")
        value_loss = values.get("train/value_loss")

        if explained_variance is None or value_loss is None:
            return True  # train() has not run yet

        if explained_variance >= self.explained_variance_threshold and value_loss > self.value_loss_threshold:
            print(f"Stopping training at step {self.num_timesteps} due to specified threshold conditions.")
            # Returning False from _on_step() is how an SB3 callback stops training
            return False
        return True


# Instantiate the custom callback with specified thresholds.
# To also write logs to the "logs" folder, call
# model.set_logger(configure("logs", ["stdout", "csv"])) before model.learn(),
# with configure imported from stable_baselines3.common.logger.
custom_stop_callback = CustomStopCallback(explained_variance_threshold=0.9, value_loss_threshold=0, starting_step=10000)

Any help on how to solve this would be appreciated!

python-3.x reinforcement-learning openai-gym stable-baselines