Computing SHAP values in the test step of a LightningModule network

I am trying to compute SHAP values in the test step of my model. Here is the code:

# For setting up the dataloaders
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Load the MNIST train dataset
mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Create a subset of the train dataset with 2000 images
train_data = Subset(mnist_train, range(2000))

# Load the MNIST test dataset
mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Create a subset of the test dataset with 100 images
test_data = Subset(mnist_test, range(100))

# Create data loaders for the train and test datasets
train_loader = DataLoader(train_data, batch_size=32)
test_loader = DataLoader(test_data, batch_size=32)

And here is the code for my model and its training:

import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms

import shap

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 256)
        self.layer_3 = nn.Linear(256, 10)
        self.explainer_created = False

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.nn.functional.softplus(self.layer_1(x))
        x = torch.nn.functional.softplus(self.layer_2(x))
        x = torch.log_softmax(self.layer_3(x), dim=1)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.nll_loss(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        # Call the create_explainer method only once
        if not self.explainer_created:
            self.create_explainer(train_data)
            self.explainer_created = True
            
        x, y = batch

        # Enable gradients
        with torch.set_grad_enabled(True):
            y_hat = self(x)
            loss = nn.functional.nll_loss(y_hat, y)
            self.log('test_loss', loss)

            # Compute SHAP values for the test images
            x.requires_grad = True
            print(x.requires_grad)
            shap_values = self.explainer.shap_values(x)
        

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)
    
    def create_explainer(self, train_data):
        x_train = torch.stack([train_data[i][0] for i in range(len(train_data))])
        
        # Create a background dataset from the train data
        background_indices = torch.randperm(len(x_train))[:100]
        background = x_train[background_indices]
        
        # Create a DeepExplainer object
        self.explainer = shap.DeepExplainer(self, background)
        print("created explainer")
        
# data
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

#train_data = # your train dataset here
#test_data = # your test dataset here

#train_loader = DataLoader(train_data, batch_size=32)
#test_loader = DataLoader(test_data, batch_size=32)

# model
model = LitModel()

# training
trainer = pl.Trainer(max_epochs=3)
trainer.fit(model, train_loader)

Everything works fine up to this point. But when I try to run

# testing
trainer.test(model, test_loader)

I get the error:

1 # testing
----> 2 trainer.test(model, test_loader)

File myenvs/condaenv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:785, in Trainer.test(self, model, dataloaders, ckpt_path, verbose, datamodule)
    783     raise TypeError(f"`Trainer.test()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
    784 self.strategy._lightning_module = model or self.lightning_module
--> 785 return call._call_and_handle_interrupt(
    786     self, self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule
    787 )

File myenvs/condaenv/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:38, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     36         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
     37     else:
---> 38         return trainer_fn(*args, **kwargs)
     40 except _TunerExitException:
     41     trainer._call_teardown_hook()

File myenvs/condaenv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:834, in Trainer._test_impl(self, model, dataloaders, ckpt_path, verbose, datamodule)
    831 self._tested_ckpt_path = self.ckpt_path  # TODO: remove in v1.8
    833 # run test
--> 834 results = self._run(model, ckpt_path=self.ckpt_path)
    836 assert self.state.stopped
    837 self.testing = False

File myenvs/condaenv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1103, in Trainer._run(self, model, ckpt_path)
   1099 self._checkpoint_connector.restore_training_state()
   1101 self._checkpoint_connector.resume_end()
-> 1103 results = self._run_stage()
   1105 log.detail(f"{self.__class__.__name__}: trainer tearing down")
   1106 self._teardown()

File myenvs/condaenv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1179, in Trainer._run_stage(self)
   1176 self.strategy.dispatch(self)
   1178 if self.evaluating:
-> 1179     return self._run_evaluate()
   1180 if self.predicting:
   1181     return self._run_predict()

File myenvs/condaenv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1219, in Trainer._run_evaluate(self)
   1214 self._evaluation_loop.trainer = self
   1216 with self.profiler.profile(f"run_{self.state.stage}_evaluation"), _evaluation_context(
   1217     self.accelerator, self._inference_mode
   1218 ):
-> 1219     eval_loop_results = self._evaluation_loop.run()
   1221 # remove the tensors from the eval results
   1222 for result in eval_loop_results:

File myenvs/condaenv/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:199, in Loop.run(self, *args, **kwargs)
    197 try:
    198     self.on_advance_start(*args, **kwargs)
--> 199     self.advance(*args, **kwargs)
    200     self.on_advance_end()
    201     self._restarting = False

File myenvs/condaenv/lib/python3.10/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py:152, in EvaluationLoop.advance(self, *args, **kwargs)
    150 if self.num_dataloaders > 1:
    151     kwargs["dataloader_idx"] = dataloader_idx
--> 152 dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
    154 # store batch level output per dataloader
    155 self._outputs.append(dl_outputs)

File myenvs/condaenv/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:199, in Loop.run(self, *args, **kwargs)
    197 try:
    198     self.on_advance_start(*args, **kwargs)
--> 199     self.advance(*args, **kwargs)
    200     self.on_advance_end()
    201     self._restarting = False

File myenvs/condaenv/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:137, in EvaluationEpochLoop.advance(self, data_fetcher, dl_max_batches, kwargs)
    134 self.batch_progress.increment_started()
    136 # lightning module methods
--> 137 output = self._evaluation_step(**kwargs)
    138 output = self._evaluation_step_end(output)
    140 self.batch_progress.increment_processed()

File myenvs/condaenv/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:234, in EvaluationEpochLoop._evaluation_step(self, **kwargs)
    223 """The evaluation step (validation_step or test_step depending on the trainer's state).
    224 
    225 Args:
   (...)
    231     the outputs of the step
    232 """
    233 hook_name = "test_step" if self.trainer.testing else "validation_step"
--> 234 output = self.trainer._call_strategy_hook(hook_name, *kwargs.values())
    236 return output

File myenvs/condaenv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1485, in Trainer._call_strategy_hook(self, hook_name, *args, **kwargs)
   1482     return
   1484 with self.profiler.profile(f"[Strategy]{self.strategy.__class__.__name__}.{hook_name}"):
-> 1485     output = fn(*args, **kwargs)
   1487 # restore current_fx when nested context
   1488 pl_module._current_fx_name = prev_fx_name

File myenvs/condaenv/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py:399, in Strategy.test_step(self, *args, **kwargs)
    397 with self.precision_plugin.test_step_context():
    398     assert isinstance(self.model, TestStep)
--> 399     return self.model.test_step(*args, **kwargs)

Cell In[4], line 47, in LitModel.test_step(self, batch, batch_idx)
     45 x.requires_grad = True
     46 print(x.requires_grad)
---> 47 shap_values = self.explainer.shap_values(x)

File myenvs/condaenv/lib/python3.10/site-packages/shap/explainers/_deep/__init__.py:124, in Deep.shap_values(self, X, ranked_outputs, output_rank_order, check_additivity)
     90 def shap_values(self, X, ranked_outputs=None, output_rank_order='max', check_additivity=True):
     91     """ Return approximate SHAP values for the model applied to the data given by X.
     92 
     93     Parameters
   (...)
    122         were chosen as "top".
    123     """
--> 124     return self.explainer.shap_values(X, ranked_outputs, output_rank_order, check_additivity=check_additivity)

File myenvs/condaenv/lib/python3.10/site-packages/shap/explainers/_deep/deep_pytorch.py:185, in PyTorchDeep.shap_values(self, X, ranked_outputs, output_rank_order, check_additivity)
    183 # run attribution computation graph
    184 feature_ind = model_output_ranks[j, i]
--> 185 sample_phis = self.gradient(feature_ind, joint_x)
    186 # assign the attributions to the right part of the output arrays
    187 if self.interim:

File myenvs/condaenv/lib/python3.10/site-packages/shap/explainers/_deep/deep_pytorch.py:121, in PyTorchDeep.gradient(self, idx, inputs)
    119 else:
    120     for idx, x in enumerate(X):
--> 121         grad = torch.autograd.grad(selected, x,
    122                                    retain_graph=True if idx + 1 < len(X) else None,
    123                                    allow_unused=True)[0]
    124         if grad is not None:
    125             grad = grad.cpu().numpy()

File myenvs/condaenv/lib/python3.10/site-packages/torch/autograd/__init__.py:300, in grad(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused, is_grads_batched)
    298     return _vmap_internals._vmap(vjp, 0, 0, allow_none_pass_through=True)(grad_outputs_)
    299 else:
--> 300     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    301         t_outputs, grad_outputs_, retain_graph, create_graph, t_inputs,
    302         allow_unused, accumulate_grad=False)

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

If you want to test the code, I suggest running it in an IPython notebook.

As you can see above, I have tried enabling gradients in the test step and replacing torch.relu() with torch.nn.functional.softplus(). Neither of these worked.
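One thing I notice in the traceback is that the evaluation loop enters _evaluation_context(self.accelerator, self._inference_mode), so my guess (an assumption, not something I have confirmed) is that test_step runs under torch.inference_mode(), which the torch.set_grad_enabled(True) block cannot override. Below is a minimal sketch of that behaviour, plus the Trainer flag I am considering as a workaround:

import torch
import pytorch_lightning as pl

# Under torch.inference_mode() no autograd graph is recorded, even when requires_grad
# is set on the input -- the same symptom as the RuntimeError above.
with torch.inference_mode():
    x = torch.rand(1, 4)
    x.requires_grad = True
    y = (x * 2).sum()
    print(y.requires_grad, y.grad_fn)  # False None

# Hedged workaround (not verified on my side): recent pytorch_lightning versions accept
# an inference_mode flag; with inference_mode=False the test loop runs under torch.no_grad(),
# which the existing torch.set_grad_enabled(True) block inside test_step can override.
trainer = pl.Trainer(max_epochs=3, inference_mode=False)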

I expect the shap.DeepExplainer object to perform the computation as intended and output the SHAP values, rather than raising the error shown above:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In fact, when I call the model's explainer object from outside the trainer with some random input, it works just fine. I only run into this error when I try to do the same thing inside test_step().

For example, this

model.explainer.shap_values(torch.rand([1, 1, 28, 28], device="cuda"))

works perfectly.
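For context, the full external sequence looks roughly like this (a sketch; the device="cuda" placement is just where my model happens to sit after training, and the per-class list is my understanding of DeepExplainer's output format):

# Built and called outside the Trainer -- this path does not raise the RuntimeError.
model.create_explainer(train_data)                   # same helper the test step uses
sample = torch.rand([1, 1, 28, 28], device="cuda")   # random stand-in for one test image
shap_values = model.explainer.shap_values(sample)
print(len(shap_values), shap_values[0].shape)        # one array per output class, each shaped like the input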

Tags: python, pytorch, mnist, pytorch-lightning, shap