我正在尝试在模型的测试步骤中计算 SHAP 值。代码如下:
# For setting up the dataloaders
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# Load the MNIST train dataset
mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Create a subset of the train dataset with 2000 images
train_data = Subset(mnist_train, range(2000))
# Load the MNIST test dataset
mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# Create a subset of the test dataset with 100 images
test_data = Subset(mnist_test, range(100))
# Create data loaders for the train and test datasets
train_loader = DataLoader(train_data, batch_size=32)
test_loader = DataLoader(test_data, batch_size=32)
这是我的模型及其训练的代码
import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
import shap
class LitModel(pl.LightningModule):
def __init__(self):
super().__init__()
self.layer_1 = nn.Linear(28 * 28, 128)
self.layer_2 = nn.Linear(128, 256)
self.layer_3 = nn.Linear(256, 10)
self.explainer_created = False
def forward(self, x):
x = x.view(x.size(0), -1)
x = torch.nn.functional.softplus(self.layer_1(x))
x = torch.nn.functional.softplus(self.layer_2(x))
x = torch.log_softmax(self.layer_3(x), dim=1)
return x
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = nn.functional.nll_loss(y_hat, y)
self.log('train_loss', loss)
return loss
def test_step(self, batch, batch_idx):
# Call the create_explainer method only once
if not self.explainer_created:
self.create_explainer(train_data)
self.explainer_created = True
x, y = batch
# Enable gradients
with torch.set_grad_enabled(True):
y_hat = self(x)
loss = nn.functional.nll_loss(y_hat, y)
self.log('test_loss', loss)
# Compute SHAP values for the test images
x.requires_grad = True
print(x.requires_grad)
shap_values = self.explainer.shap_values(x)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.02)
def create_explainer(self, train_data):
x_train = torch.stack([train_data[i][0] for i in range(len(train_data))])
# Create a background dataset from the train data
background_indices = torch.randperm(len(x_train))[:100]
background = x_train[background_indices]
# Create a DeepExplainer object
self.explainer = shap.DeepExplainer(self, background)
print("created explainer")
# data
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
#train_data = # your train dataset here
#test_data = # your test dataset here
#train_loader = DataLoader(train_data, batch_size=32)
#test_loader = DataLoader(test_data, batch_size=32)
# model
model = LitModel()
# training
trainer = pl.Trainer(max_epochs=3)
trainer.fit(model, train_loader)
一切正常,直到这里。当我尝试做
# testing
trainer.test(model, test_loader)
我看到错误:
1 # testing
----> 2 trainer.test(model, test_loader)
File myenvs/condaenv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:785, in Trainer.test(self, model, dataloaders, ckpt_path, verbose, datamodule)
783 raise TypeError(f"`Trainer.test()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
784 self.strategy._lightning_module = model or self.lightning_module
--> 785 return call._call_and_handle_interrupt(
786 self, self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule
787 )
File myenvs/condaenv/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:38, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
36 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
37 else:
---> 38 return trainer_fn(*args, **kwargs)
40 except _TunerExitException:
41 trainer._call_teardown_hook()
File myenvs/condaenv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:834, in Trainer._test_impl(self, model, dataloaders, ckpt_path, verbose, datamodule)
831 self._tested_ckpt_path = self.ckpt_path # TODO: remove in v1.8
833 # run test
--> 834 results = self._run(model, ckpt_path=self.ckpt_path)
836 assert self.state.stopped
837 self.testing = False
File myenvs/condaenv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1103, in Trainer._run(self, model, ckpt_path)
1099 self._checkpoint_connector.restore_training_state()
1101 self._checkpoint_connector.resume_end()
-> 1103 results = self._run_stage()
1105 log.detail(f"{self.__class__.__name__}: trainer tearing down")
1106 self._teardown()
File myenvs/condaenv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1179, in Trainer._run_stage(self)
1176 self.strategy.dispatch(self)
1178 if self.evaluating:
-> 1179 return self._run_evaluate()
1180 if self.predicting:
1181 return self._run_predict()
File myenvs/condaenv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1219, in Trainer._run_evaluate(self)
1214 self._evaluation_loop.trainer = self
1216 with self.profiler.profile(f"run_{self.state.stage}_evaluation"), _evaluation_context(
1217 self.accelerator, self._inference_mode
1218 ):
-> 1219 eval_loop_results = self._evaluation_loop.run()
1221 # remove the tensors from the eval results
1222 for result in eval_loop_results:
File myenvs/condaenv/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:199, in Loop.run(self, *args, **kwargs)
197 try:
198 self.on_advance_start(*args, **kwargs)
--> 199 self.advance(*args, **kwargs)
200 self.on_advance_end()
201 self._restarting = False
File myenvs/condaenv/lib/python3.10/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py:152, in EvaluationLoop.advance(self, *args, **kwargs)
150 if self.num_dataloaders > 1:
151 kwargs["dataloader_idx"] = dataloader_idx
--> 152 dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
154 # store batch level output per dataloader
155 self._outputs.append(dl_outputs)
File myenvs/condaenv/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:199, in Loop.run(self, *args, **kwargs)
197 try:
198 self.on_advance_start(*args, **kwargs)
--> 199 self.advance(*args, **kwargs)
200 self.on_advance_end()
201 self._restarting = False
File myenvs/condaenv/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:137, in EvaluationEpochLoop.advance(self, data_fetcher, dl_max_batches, kwargs)
134 self.batch_progress.increment_started()
136 # lightning module methods
--> 137 output = self._evaluation_step(**kwargs)
138 output = self._evaluation_step_end(output)
140 self.batch_progress.increment_processed()
File myenvs/condaenv/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:234, in EvaluationEpochLoop._evaluation_step(self, **kwargs)
223 """The evaluation step (validation_step or test_step depending on the trainer's state).
224
225 Args:
(...)
231 the outputs of the step
232 """
233 hook_name = "test_step" if self.trainer.testing else "validation_step"
--> 234 output = self.trainer._call_strategy_hook(hook_name, *kwargs.values())
236 return output
File myenvs/condaenv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1485, in Trainer._call_strategy_hook(self, hook_name, *args, **kwargs)
1482 return
1484 with self.profiler.profile(f"[Strategy]{self.strategy.__class__.__name__}.{hook_name}"):
-> 1485 output = fn(*args, **kwargs)
1487 # restore current_fx when nested context
1488 pl_module._current_fx_name = prev_fx_name
File myenvs/condaenv/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py:399, in Strategy.test_step(self, *args, **kwargs)
397 with self.precision_plugin.test_step_context():
398 assert isinstance(self.model, TestStep)
--> 399 return self.model.test_step(*args, **kwargs)
Cell In[4], line 47, in LitModel.test_step(self, batch, batch_idx)
45 x.requires_grad = True
46 print(x.requires_grad)
---> 47 shap_values = self.explainer.shap_values(x)
File myenvs/condaenv/lib/python3.10/site-packages/shap/explainers/_deep/__init__.py:124, in Deep.shap_values(self, X, ranked_outputs, output_rank_order, check_additivity)
90 def shap_values(self, X, ranked_outputs=None, output_rank_order='max', check_additivity=True):
91 """ Return approximate SHAP values for the model applied to the data given by X.
92
93 Parameters
(...)
122 were chosen as "top".
123 """
--> 124 return self.explainer.shap_values(X, ranked_outputs, output_rank_order, check_additivity=check_additivity)
File myenvs/condaenv/lib/python3.10/site-packages/shap/explainers/_deep/deep_pytorch.py:185, in PyTorchDeep.shap_values(self, X, ranked_outputs, output_rank_order, check_additivity)
183 # run attribution computation graph
184 feature_ind = model_output_ranks[j, i]
--> 185 sample_phis = self.gradient(feature_ind, joint_x)
186 # assign the attributions to the right part of the output arrays
187 if self.interim:
File myenvs/condaenv/lib/python3.10/site-packages/shap/explainers/_deep/deep_pytorch.py:121, in PyTorchDeep.gradient(self, idx, inputs)
119 else:
120 for idx, x in enumerate(X):
--> 121 grad = torch.autograd.grad(selected, x,
122 retain_graph=True if idx + 1 < len(X) else None,
123 allow_unused=True)[0]
124 if grad is not None:
125 grad = grad.cpu().numpy()
File myenvs/condaenv/lib/python3.10/site-packages/torch/autograd/__init__.py:300, in grad(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused, is_grads_batched)
298 return _vmap_internals._vmap(vjp, 0, 0, allow_none_pass_through=True)(grad_outputs_)
299 else:
--> 300 return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
301 t_outputs, grad_outputs_, retain_graph, create_graph, t_inputs,
302 allow_unused, accumulate_grad=False)
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
如果你正在测试代码,我建议在 ipython notebook 中运行它。
正如您在上面看到的,我尝试在测试步骤中启用渐变,并将 torch.relu() 替换为 torch.nn.functional.softplus()。这些都没有用。
我希望 shap.DeepExplainer 对象按预期执行计算并输出值,而不是如上所示的错误
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
。
事实上,当我尝试使用一些随机值从外部调用模型的
explainer
对象时,它工作得很好。仅当我尝试在 test_step() 中执行相同操作时才会遇到此错误
这个
model.explainer.shap_values(torch.rand([1,1,28,28], device="cuda"))
非常好用。