我目前正在使用 stable_baselines3 库和 gym_anytrading 编写强化学习模型。我已经为一个环境编写了代码来训练模型,并且有很多时间步长。
但是,当
explained variance
处于不理想的水平时,模型通常会完成训练。因此,我想编写一个回调函数,当explained variance
在一个范围内时帮助停止训练,例如在 0.9 到 1 之间
这是我到目前为止创建的环境。
env_maker = lambda: MyCustomEnv(df=df, frame_bound=(12,30660), window_size=12)
env = DummyVecEnv([env_maker])
model = A2C('MlpPolicy', env, verbose=1, policy_kwargs=dict(net_arch=[dict(pi=[128, 256, 128], vf=[128, 256, 128])]))
# Pass the custom callback to the learn() method
model.learn(total_timesteps=1000000, callback=custom_stop_callback)
这是回调函数,当
explained variance
高于某个值时,应该停止模型。
class CustomLogger(logger.Logger):
def __init__(self, folder, output_formats, *args, **kwargs):
super().__init__(folder, output_formats, *args, **kwargs)
self.buffer = []
def get_writer(self) -> KVWriter:
return self
def _write(self, key_values, key_excluded):
self.buffer.append((key_values, key_excluded))
class CustomStopCallback(BaseCallback):
def __init__(self, logger, explained_variance_threshold: float, value_loss_threshold: float, starting_step: int = 0):
super(CustomStopCallback, self).__init__()
self.logger = logger
self.explained_variance_threshold = explained_variance_threshold
self.value_loss_threshold = value_loss_threshold
self.starting_step = starting_step
def _on_step(self) -> bool:
return True
def _on_rollout_end(self) -> None:
if self.num_timesteps >= self.starting_step:
log_buffer = self.logger.buffer
explained_variance = None
value_loss = None
for record in log_buffer:
key_values, _ = record
if "explained_variance" in key_values:
explained_variance = key_values["explained_variance"]
if "value_loss" in key_values:
value_loss = key_values["value_loss"]
if explained_variance is not None and value_loss is not None:
if explained_variance >= self.explained_variance_threshold and value_loss > self.value_loss_threshold:
print(f"Stopping training at step {self.num_timesteps} due to specified threshold conditions.")
self.model.set_attr('stop_training', True)
folder = "logs"
logger.configure(folder=folder)
# Instantiate the custom callback with specified thresholds
custom_stop_callback = CustomStopCallback(logger, explained_variance_threshold=0.9, value_loss_threshold=0, starting_step=10000)
感谢任何有关如何解决此问题的帮助!