我已经使用 stable-baselines3 训练了 A2C 模型(MlpPolicy)(我对强化学习很陌生,发现这是一个很好的起点)。 然而,我现在想使用XRL(eXplainable Reinforcement Learning)方法来更好地理解模型。我决定使用 DeepSHAP,因为它有一个很好的实现,而且我熟悉 SHAP。 DeepSHAP 在 PyTorch 上运行,PyTorch 是 stable-baselines3 背后的底层框架。所以我的目标是从 stable-baselines3 模型中提取底层 PyTorch 模型。但是,我对此有一些问题。
我找到了以下线程:https://github.com/hill-a/stable-baselines/issues/372 这个帖子确实对我有一点帮助,但是,由于 A2C 的架构与这个帖子中使用的模型不同,我还无法解决我的问题。
据我了解,stable-baselines3 提供了使用导出模型的选项
model.policy.state_dict()
但是,我很难导入通过该方法导出的内容。
打印时
A2C_model.policy
我大致了解了 PyTorch 模型的结构。输出:
ActorCriticPolicy(
(features_extractor): FlattenExtractor(
(flatten): Flatten(start_dim=1, end_dim=-1)
)
(pi_features_extractor): FlattenExtractor(
(flatten): Flatten(start_dim=1, end_dim=-1)
)
(vf_features_extractor): FlattenExtractor(
(flatten): Flatten(start_dim=1, end_dim=-1)
)
(mlp_extractor): MlpExtractor(
(policy_net): Sequential(
(0): Linear(in_features=49, out_features=64, bias=True)
(1): Tanh()
(2): Linear(in_features=64, out_features=64, bias=True)
(3): Tanh()
)
(value_net): Sequential(
(0): Linear(in_features=49, out_features=64, bias=True)
(1): Tanh()
(2): Linear(in_features=64, out_features=64, bias=True)
(3): Tanh()
)
)
(action_net): Linear(in_features=64, out_features=5, bias=True)
(value_net): Linear(in_features=64, out_features=1, bias=True)
)
我尝试自己重新创建它,但我对 PyTorch 还不够熟练,还无法让它工作......
所以我的问题是:如何将 stable_baselines3 模型导出到 PyTorch?
我尝试根据打印 A2C_model.policy 的输出在 PyTorch 中重新构建模型架构。我的代码目前是:
import torch as th
import torch.nn as nn
class PyTorchMlp(nn.Module):
def __init__(self):
nn.Module.__init__(self)
n_inputs = 49
n_actions = 5
self.features_extractor = nn.Flatten(start_dim = 1, end_dim = -1)
self.pi_features_extractor = nn.Flatten(start_dim = 1, end_dim = -1)
self.vf_features_extractor = nn.Flatten(start_dim = 1, end_dim = -1)
self.mlp_extractor = nn.Sequentail(
self.policy_net = nn.Sequential(
nn.Linear(in_features = n_inputs, out_features = 64),
nn.Tanh(),
nn.Linear(in_features = 64, out_features = 64),
nn.Tanh()
),
self.value_net = nn.Sequential(
nn.Linear(in_features = n_inputs, out_features = 64),
nn.Tanh(),
nn.Linear(in_features = 64, out_features = 64),
nn.Tanh()
)
)
self.action_net = nn.Linear(in_features = 64, out_features = 5)
self.value_net = nn.Linear(in_features = 64, out_features = 1)
def forward(self, x):
pass
如果您只想将其导出为 pytorch 模型以便使用 shap 框架中的 DeepExplainer,您所需要做的就是创建一个类来将模型的
policy_net
和 action_net
包装在一起。我的解决方案是实现 stable-baselines3 的 PPO (MLP) 模型,但我确信它对于 A2C 来说不会有什么不同。
我的 PPO (MLP) 模型的包装类:
import shap
import torch
import torch.nn as nn
from stable_baselines3 import PPO
class sb3Wrapper(nn.Module):
def __init__(self, model):
super(sb3Wrapper,self).__init__()
self.extractor = model.policy.mlp_extractor
self.policy_net = model.policy.mlp_extractor.policy_net
self.action_net = model.policy.action_net
def forward(self,x):
x = self.policy_net(x)
x = self.action_net(x)
return x
关于 shap 框架深度解释器,您需要确定一些事情
DeepExplainer
函数的模型和状态数据(作为火炬张量)位于同一设备上(即“cuda”/“cpu”)torch.FloatTensor()
函数以下是我的实现中的几行内容,可以帮助您: (我提取所有数据并将其存储在数据框中,因为我还在执行其他分析)
model = PPO.load(model_path, device='cuda')
state_log = np.array(df['observation'].values.tolist())
data = torch.FloatTensor(state_log).to('cuda')
model = sb3Wrapper(model)
explainer = shap.DeepExplainer(model, data)
shap_vals= explainer.shap_values(data)
参考资料和有用的链接:
DeepExplainer
的有用信息。