如何将使用 stable-baselines3 创建的 A2C 模型导出到 PyTorch?

问题描述 投票:0回答:1

我已经使用 stable-baselines3 训练了 A2C 模型(MlpPolicy)(我对强化学习很陌生,发现这是一个很好的起点)。 然而,我现在想使用XRL(eXplainable Reinforcement Learning)方法来更好地理解模型。我决定使用 DeepSHAP,因为它有一个很好的实现,而且我熟悉 SHAP。 DeepSHAP 在 PyTorch 上运行,PyTorch 是 stable-baselines3 背后的底层框架。所以我的目标是从 stable-baselines3 模型中提取底层 PyTorch 模型。但是,我对此有一些问题。

我找到了以下线程:https://github.com/hill-a/stable-baselines/issues/372 这个帖子确实对我有一点帮助,但是,由于 A2C 的架构与这个帖子中使用的模型不同,我还无法解决我的问题。

据我了解,stable-baselines3 提供了使用导出模型的选项

model.policy.state_dict()

但是,我很难导入通过该方法导出的内容。

打印时

A2C_model.policy

我大致了解了 PyTorch 模型的结构。输出:

ActorCriticPolicy(
  (features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (pi_features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (vf_features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (mlp_extractor): MlpExtractor(
    (policy_net): Sequential(
      (0): Linear(in_features=49, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): Tanh()
    )
    (value_net): Sequential(
      (0): Linear(in_features=49, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): Tanh()
    )
  )
  (action_net): Linear(in_features=64, out_features=5, bias=True)
  (value_net): Linear(in_features=64, out_features=1, bias=True)
)

我尝试自己重新创建它,但我对 PyTorch 还不够熟练,还无法让它工作......

所以我的问题是:如何将 stable_baselines3 模型导出到 PyTorch?

我尝试根据打印 A2C_model.policy 的输出在 PyTorch 中重新构建模型架构。我的代码目前是:

import torch as th
import torch.nn as nn

class PyTorchMlp(nn.Module):  
        def __init__(self):
                nn.Module.__init__(self)

                n_inputs = 49
                n_actions = 5
        
                self.features_extractor = nn.Flatten(start_dim = 1, end_dim = -1)
        
                self.pi_features_extractor = nn.Flatten(start_dim = 1, end_dim = -1)
        
                self.vf_features_extractor = nn.Flatten(start_dim = 1, end_dim = -1)
        
                self.mlp_extractor = nn.Sequentail(
                    self.policy_net = nn.Sequential(
                        nn.Linear(in_features = n_inputs, out_features = 64),
                        nn.Tanh(),
                        nn.Linear(in_features = 64, out_features = 64),
                        nn.Tanh()
                    ),
        
                    self.value_net = nn.Sequential(
                        nn.Linear(in_features = n_inputs, out_features = 64),
                        nn.Tanh(),
                        nn.Linear(in_features = 64, out_features = 64),
                        nn.Tanh()
                    )
                )
        
                self.action_net = nn.Linear(in_features = 64, out_features = 5)
        
                self.value_net = nn.Linear(in_features = 64, out_features = 1)
        
    
            def forward(self, x):
                pass
deep-learning pytorch neural-network reinforcement-learning stable-baselines
1个回答
1
投票

如果您只想将其导出为 pytorch 模型以便使用 shap 框架中的 DeepExplainer,您所需要做的就是创建一个类来将模型的

policy_net
action_net
包装在一起。我的解决方案是实现 stable-baselines3 的 PPO (MLP) 模型,但我确信它对于 A2C 来说不会有什么不同。

我的 PPO (MLP) 模型的包装类:

import shap
import torch
import torch.nn as nn
from stable_baselines3 import PPO

class sb3Wrapper(nn.Module):
    def __init__(self, model):
        super(sb3Wrapper,self).__init__()
        self.extractor = model.policy.mlp_extractor
        self.policy_net = model.policy.mlp_extractor.policy_net
        self.action_net = model.policy.action_net

    def forward(self,x):
        x = self.policy_net(x)
        x = self.action_net(x)
        return x

关于 shap 框架深度解释器,您需要确定一些事情

  1. 您需要确保传递到
    DeepExplainer
    函数的模型和状态数据(作为火炬张量)位于同一设备上(即“cuda”/“cpu”)
  2. 如果您的状态数据是连续的,请确保使用
    torch.FloatTensor()
    函数

以下是我的实现中的几行内容,可以帮助您: (我提取所有数据并将其存储在数据框中,因为我还在执行其他分析)

model = PPO.load(model_path, device='cuda')
state_log = np.array(df['observation'].values.tolist())
data = torch.FloatTensor(state_log).to('cuda')
model = sb3Wrapper(model)
explainer = shap.DeepExplainer(model, data)
shap_vals= explainer.shap_values(data)

参考资料和有用的链接:

  • Shap 框架 Github 问题,该人想要将 stable_baselines3 DQN 与 DeepExplainer 一起使用
  • 关于访问网络每一层的
  • stable-baselines3 Github问题
    • 如果您想逐层包装模型,这可能会对您有所帮助(我还没有尝试过)
  • Kaggle 代码,其中包含有关在 pytorch 模型上使用
    DeepExplainer
    的有用信息。
© www.soinside.com 2019 - 2024. All rights reserved.