我正在致力于创建一个基于 LSTM 的强化学习模型,并尝试了解 sb3-contrib 的 Recurrent PPO 的工作原理。这是代码的简化示例:
# import gym
# from gym import spaces
# import torch
# import numpy as np
# from sb3_contrib import RecurrentPPO
class env_LSTM(gym.Env):
def __init__(self, qnt_steps, qnt_features, qnt_actions, states_input):
super(env_LSTM, self).__init__()
self.states_input = states_input
self.index_states_input = 0
self.observation_space = spaces.Box(low=0, high=1, shape=(qnt_steps, qnt_features), dtype=np.float64)
self.action_space = spaces.Discrete(qnt_actions)
def reset(self):
self.index_states_input = 0
return self.states_input[self.index_states_input]
def step(self, action):
self.index_states_input += 1
done = self.index_states_input == len(self.states_input)-1
return self.states_input[self.index_states_input], 0, done, {'a':1}
steps, features, actions = 10, 5, 3
states = torch.randn(20, steps, features)
env = env_LSTM(steps,features,actions,states)
model = RecurrentPPO('MlpLstmPolicy', env, verbose=1)
在这个场景中,我正在创建一个具有 10 个时间步长和 5 个特征的观察空间。这些状态是随机生成的,用于演示目的。
但是,我无法理解模型的架构。当我检查 model.policy 时,它返回:
model.policy
返回:
RecurrentActorCriticPolicy(
(features_extractor): FlattenExtractor(
(flatten): Flatten(start_dim=1, end_dim=-1)
)
(pi_features_extractor): FlattenExtractor(
(flatten): Flatten(start_dim=1, end_dim=-1)
)
(vf_features_extractor): FlattenExtractor(
(flatten): Flatten(start_dim=1, end_dim=-1)
)
(mlp_extractor): MlpExtractor(
(policy_net): Sequential(
(0): Linear(in_features=256, out_features=64, bias=True)
(1): Tanh()
(2): Linear(in_features=64, out_features=64, bias=True)
(3): Tanh()
)
(value_net): Sequential(
(0): Linear(in_features=256, out_features=64, bias=True)
(1): Tanh()
(2): Linear(in_features=64, out_features=64, bias=True)
(3): Tanh()
)
)
(action_net): Linear(in_features=64, out_features=3, bias=True)
(value_net): Linear(in_features=64, out_features=1, bias=True)
(lstm_actor): LSTM(50, 256)
(lstm_critic): LSTM(50, 256)
)
LSTM层有50个输入特征,但根据shape=(10, 5)定义的observation_space应该是5个。它似乎是 10 乘以 5,并考虑 50 个不同的特征。
当我提供像 torch.randn(10, 5) 这样的输入并将其传递给 model.predict() 时,模型会正常生成预测,生成一个带有操作的元组和我相信的未来状态的预测。
在这部分:
self.observation_space = spaces.Box(low=0, high=1, shape=(qnt_steps, qnt_features), dtype=np.float64)
我尝试仅使用观察空间中的特征数量:
self.observation_space = spaces.Box(low=0, high=1, shape=(qnt_features,), dtype=np.float64)
但是当我将输入 torch.randn(10, 5) 传递给模型时,它会考虑 10 个不同的状态,每个状态有 5 个特征,并对每个状态进行预测,从而产生 10 个动作。
我不确定这是如何运作的。如果有人有见解,我会很感激解释。
此外,我在使用 PPO2 和 MlpLstmPolicy 时遇到了稳定基线问题。当我在数据集上使用 model.learn 时,它会消耗所有内存并冻结笔记本。我尝试减少批处理大小以消耗更少的内存,但它仍然挂起。我已经在环境中循环访问了数据集,没有任何问题;使用 model.learn 时出现问题。
谢谢您的帮助!
我尝试仅使用观察空间中的特征数量:
self.observation_space = spaces.Box(low=0, high=1, shape=(qnt_features,), dtype=np.float64)
但是当我将输入 torch.randn(10, 5) 传递给模型时,它会考虑 10 个不同的状态,每个状态有 5 个特征,并对每个状态进行预测,从而产生 10 个动作。
我有同样的问题,似乎我们 FlattenExtractor 的 features_dim 参数只接受 int 变量而不是元组。你还有什么线索吗?谢谢!