如何使用LSTM?来自 sb3-contrib 的经常性 PPO

问题描述 投票:0回答:1

我正在致力于创建一个基于 LSTM 的强化学习模型,并尝试了解 sb3-contrib 的 Recurrent PPO 的工作原理。这是代码的简化示例:

# import gym
# from gym import spaces
# import torch
# import numpy as np
# from sb3_contrib import RecurrentPPO

class env_LSTM(gym.Env):
    def __init__(self, qnt_steps, qnt_features, qnt_actions, states_input):
        super(env_LSTM, self).__init__()
        self.states_input = states_input
        self.index_states_input = 0
        self.observation_space = spaces.Box(low=0, high=1, shape=(qnt_steps, qnt_features), dtype=np.float64)
        self.action_space = spaces.Discrete(qnt_actions)
    def reset(self):
        self.index_states_input = 0
        return self.states_input[self.index_states_input]
    def step(self, action):
        self.index_states_input += 1
        done = self.index_states_input == len(self.states_input)-1
        return self.states_input[self.index_states_input], 0, done, {'a':1}
    
steps, features, actions = 10, 5, 3
    
states = torch.randn(20, steps, features)    
env = env_LSTM(steps,features,actions,states)    

model = RecurrentPPO('MlpLstmPolicy', env, verbose=1)

在这个场景中,我正在创建一个具有 10 个时间步长和 5 个特征的观察空间。这些状态是随机生成的,用于演示目的。

但是,我无法理解模型的架构。当我检查 model.policy 时,它返回:

model.policy

返回:

    RecurrentActorCriticPolicy(
    (features_extractor): FlattenExtractor(
        (flatten): Flatten(start_dim=1, end_dim=-1)
    )
    (pi_features_extractor): FlattenExtractor(
        (flatten): Flatten(start_dim=1, end_dim=-1)
    )
    (vf_features_extractor): FlattenExtractor(
        (flatten): Flatten(start_dim=1, end_dim=-1)
    )
    (mlp_extractor): MlpExtractor(
        (policy_net): Sequential(
        (0): Linear(in_features=256, out_features=64, bias=True)
        (1): Tanh()
        (2): Linear(in_features=64, out_features=64, bias=True)
        (3): Tanh()
        )
        (value_net): Sequential(
        (0): Linear(in_features=256, out_features=64, bias=True)
        (1): Tanh()
        (2): Linear(in_features=64, out_features=64, bias=True)
        (3): Tanh()
        )
    )
    (action_net): Linear(in_features=64, out_features=3, bias=True)
    (value_net): Linear(in_features=64, out_features=1, bias=True)
    (lstm_actor): LSTM(50, 256)
    (lstm_critic): LSTM(50, 256)
    )

LSTM层有50个输入特征,但根据shape=(10, 5)定义的observation_space应该是5个。它似乎是 10 乘以 5,并考虑 50 个不同的特征。

当我提供像 torch.randn(10, 5) 这样的输入并将其传递给 model.predict() 时,模型会正常生成预测,生成一个带有操作的元组和我相信的未来状态的预测。

在这部分:

        self.observation_space = spaces.Box(low=0, high=1, shape=(qnt_steps, qnt_features), dtype=np.float64)

我尝试仅使用观察空间中的特征数量:

        self.observation_space = spaces.Box(low=0, high=1, shape=(qnt_features,), dtype=np.float64)

但是当我将输入 torch.randn(10, 5) 传递给模型时,它会考虑 10 个不同的状态,每个状态有 5 个特征,并对每个状态进行预测,从而产生 10 个动作。

我不确定这是如何运作的。如果有人有见解,我会很感激解释。

此外,我在使用 PPO2 和 MlpLstmPolicy 时遇到了稳定基线问题。当我在数据集上使用 model.learn 时,它会消耗所有内存并冻结笔记本。我尝试减少批处理大小以消耗更少的内存,但它仍然挂起。我已经在环境中循环访问了数据集,没有任何问题;使用 model.learn 时出现问题。

谢谢您的帮助!

我尝试仅使用观察空间中的特征数量:

        self.observation_space = spaces.Box(low=0, high=1, shape=(qnt_features,), dtype=np.float64)

但是当我将输入 torch.randn(10, 5) 传递给模型时,它会考虑 10 个不同的状态,每个状态有 5 个特征,并对每个状态进行预测,从而产生 10 个动作。

deep-learning pytorch lstm recurrent-neural-network stable-baselines
1个回答
0
投票

我有同样的问题,似乎我们 FlattenExtractor 的 features_dim 参数只接受 int 变量而不是元组。你还有什么线索吗?谢谢!

© www.soinside.com 2019 - 2024. All rights reserved.