How to dump a confusion matrix using the TensorBoard logger in pytorch-lightning?

Question (votes: 0, answers: 3)

The official documentation only shows:

>>> from pytorch_lightning.metrics import ConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
>>> confmat = ConfusionMatrix(num_classes=2)
>>> confmat(preds, target)

This doesn't show how to use the metric within the framework.
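For the input in the documentation example, the metric counts (target, prediction) pairs. A NumPy-only sketch of that computation, assuming the common convention in which rows are true classes and columns are predictions (the convention scikit-learn documents); confusion_matrix here is an illustrative helper, not the Lightning function:

```python
import numpy as np


def confusion_matrix(preds, target, num_classes):
    """Count (true, predicted) pairs; rows are true classes, columns are predictions."""
    preds = np.asarray(preds, dtype=np.int64)
    target = np.asarray(target, dtype=np.int64)
    # Encode each (true, pred) pair as one flat index, then count occurrences.
    flat = np.bincount(target * num_classes + preds, minlength=num_classes ** 2)
    return flat.reshape(num_classes, num_classes)


cm = confusion_matrix([0, 1, 0, 0], [1, 1, 0, 0], num_classes=2)
print(cm)
# [[2 0]
#  [1 1]]
```

Both class-0 samples are predicted correctly, while one of the two class-1 samples is predicted as 0, which is the 1 in the bottom-left cell.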

My attempt (the methods are incomplete; only the relevant parts are shown):

def __init__(...):
    self.val_confusion = pl.metrics.classification.ConfusionMatrix(num_classes=self._config.n_clusters)

def validation_step(self, batch, batch_index):
    ...
    log_probs = self.forward(orig_batch)
    loss = self._criterion(log_probs, label_batch)
   
    self.val_confusion.update(log_probs, label_batch)
    self.log('validation_confusion_step', self.val_confusion, on_step=True, on_epoch=False)

def validation_step_end(self, outputs):
    return outputs

def validation_epoch_end(self, outs):
    self.log('validation_confusion_epoch', self.val_confusion.compute())

After epoch 0, this gives

    Traceback (most recent call last):
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 521, in train
        self.train_loop.run_training_epoch()
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\training_loop.py", line 588, in run_training_epoch
        self.trainer.run_evaluation(test_mode=False)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 613, in run_evaluation
        self.evaluation_loop.log_evaluation_step_metrics(output, batch_idx)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\evaluation_loop.py", line 346, in log_evaluation_step_metrics
        self.__log_result_step_metrics(step_log_metrics, step_pbar_metrics, batch_idx)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\evaluation_loop.py", line 350, in __log_result_step_metrics
        cached_batch_pbar_metrics, cached_batch_log_metrics = cached_results.update_logger_connector()
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 378, in update_logger_connector
        batch_log_metrics = self.get_latest_batch_log_metrics()
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 418, in get_latest_batch_log_metrics
        batch_log_metrics = self.run_batch_from_func_name("get_batch_log_metrics")
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 414, in run_batch_from_func_name
        results = [func(include_forked_originals=False) for func in results]
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 414, in <listcomp>
        results = [func(include_forked_originals=False) for func in results]
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 122, in get_batch_log_metrics
        return self.run_latest_batch_metrics_with_func_name("get_batch_log_metrics", *args, **kwargs)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 115, in run_latest_batch_metrics_with_func_name
        for dl_idx in range(self.num_dataloaders)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 115, in <listcomp>
        for dl_idx in range(self.num_dataloaders)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 100, in get_latest_from_func_name
        results.update(func(*args, add_dataloader_idx=add_dataloader_idx, **kwargs))
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\core\step_result.py", line 298, in get_batch_log_metrics
        result[dl_key] = self[k]._forward_cache.detach()
    AttributeError: 'NoneType' object has no attribute 'detach'


It does pass the sanity validation check before training.

The failure happens on the return in validation_step_end. That makes no sense to me.

Using other metrics in exactly the same way works fine and gives correct results.
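Judging from the last frame of the traceback, Lightning tried to log the metric's per-step value from its _forward_cache attribute, which appears to be filled only when the metric object is called (its forward), not by update(). As far as I know, self.log is also meant for scalar values, so a 2-D confusion matrix would not log cleanly that way in any case. The class below is a hypothetical, NumPy-only model of that state handling (it is not the torchmetrics API), showing the difference between update(), calling the object, compute() and reset():

```python
import numpy as np


class ConfusionAccumulator:
    """Hypothetical stand-in for a stateful metric; NOT the torchmetrics API."""

    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.reset()

    def reset(self):
        # Clear the epoch-level state and the per-step cache.
        self.state = np.zeros((self.num_classes, self.num_classes), dtype=np.int64)
        self.forward_cache = None

    def _batch_matrix(self, preds, target):
        idx = np.asarray(target) * self.num_classes + np.asarray(preds)
        counts = np.bincount(idx, minlength=self.num_classes ** 2)
        return counts.reshape(self.num_classes, self.num_classes)

    def update(self, preds, target):
        # Accumulate epoch state only; the per-batch value is NOT cached.
        self.state += self._batch_matrix(preds, target)

    def __call__(self, preds, target):
        # "forward": accumulate AND cache the per-batch value for step-level logging.
        batch = self._batch_matrix(preds, target)
        self.state += batch
        self.forward_cache = batch
        return batch

    def compute(self):
        # Epoch total; does not clear the state by itself.
        return self.state.copy()
```

With this model, update() leaves forward_cache as None, which mirrors the NoneType error above; compute() returns the epoch total, and reset() has to be called between epochs whenever the metric is not handed to self.log.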

How do I get a correct confusion matrix?

python deep-learning pytorch tensorboard pytorch-lightning
3 Answers
13 votes

Outdated for Lightning >= 2.0.0

You can report the figure with self.logger.experiment.add_figure(*tag*, *figure*).

The variable self.logger.experiment is actually a SummaryWriter (from PyTorch, not Lightning). This class has an add_figure method (see its documentation).

You can use it as follows (MNIST example):

    import matplotlib.pyplot as plt
    import pandas as pd
    import pytorch_lightning as pl
    import seaborn as sns
    import torch
    import torch.nn.functional as F

    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)  # log-probabilities, shape (batch, 10)
        loss = F.nll_loss(preds, y)
        return {'loss': loss, 'preds': preds, 'target': y}

    def validation_epoch_end(self, outputs):
        # Gather the per-batch outputs returned by validation_step
        preds = torch.cat([tmp['preds'] for tmp in outputs])
        targets = torch.cat([tmp['target'] for tmp in outputs])
        confusion_matrix = pl.metrics.functional.confusion_matrix(preds, targets, num_classes=10)

        # .cpu() is required before .numpy() when validating on a GPU
        df_cm = pd.DataFrame(confusion_matrix.cpu().numpy(), index=range(10), columns=range(10))
        plt.figure(figsize=(10, 7))
        fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
        plt.close(fig_)

        self.logger.experiment.add_figure("Confusion matrix", fig_, self.current_epoch)
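In this answer, preds holds log-probabilities of shape (batch, 10), not class indices, so the epoch-end code has to reduce them to predicted labels before counting. The older pl.metrics functional appears to do that internally for probability inputs; the NumPy sketch below does it explicitly. It assumes the same 'preds'/'target' dictionary keys as validation_step above, with NumPy arrays standing in for tensors (epoch_confusion is a hypothetical helper):

```python
import numpy as np


def epoch_confusion(outputs, num_classes):
    """Concatenate per-batch step outputs and build one epoch-level confusion matrix.

    `outputs` mimics the list of dicts returned by validation_step, with
    'preds' holding (batch, num_classes) log-probabilities and 'target' labels.
    """
    preds = np.concatenate([o["preds"] for o in outputs])
    targets = np.concatenate([o["target"] for o in outputs])
    pred_labels = preds.argmax(axis=1)  # most likely class per sample
    idx = targets * num_classes + pred_labels
    counts = np.bincount(idx, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)
```

Taking argmax of log-probabilities gives the same label as argmax of probabilities, since log is monotonic.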

6 votes

This is outdated.

See the better version.


Finding this took a lot of time.

This is the minimal code I could paste that is still readable and reproducible.

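I didn't want to include the whole model, dataset and parameters here, since they are of no interest to readers of this question and would only be noise.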


That said, here is the code required to create a confusion matrix for every epoch and display it in TensorBoard.

Here is a single-frame example:


import io

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import seaborn as sn
import torchvision
from PIL import Image

def __init__(self, config, trained_vae, latent_dim):
    super().__init__()
    self._config = config
    self.val_confusion = pl.metrics.classification.ConfusionMatrix(num_classes=self._config.n_clusters)
    # self.logger is attached by the Trainer (a TensorBoardLogger here), so it is not set manually

def forward(self, x):
    ...
    return log_probs

def validation_step(self, batch, batch_index):
    if self._config.dataset == "mnist":
        orig_batch, label_batch = batch
        orig_batch = orig_batch.reshape(-1, 28 * 28)

    log_probs = self.forward(orig_batch)
    loss = self._criterion(log_probs, label_batch)

    # Only accumulate the metric state here; the matrix is not passed to self.log
    self.val_confusion.update(log_probs, label_batch)
    return {"loss": loss, "labels": label_batch}

def validation_step_end(self, outputs):
    return outputs

def validation_epoch_end(self, outs):
    tb = self.logger.experiment  # the underlying SummaryWriter

    # confusion matrix accumulated over the whole validation epoch
    # (np.int was removed in NumPy 1.24, so use the builtin int)
    conf_mat = self.val_confusion.compute().detach().cpu().numpy().astype(int)
    self.val_confusion.reset()  # start the next epoch from zero counts
    df_cm = pd.DataFrame(
        conf_mat,
        index=np.arange(self._config.n_clusters),
        columns=np.arange(self._config.n_clusters))
    fig = plt.figure()
    sn.set(font_scale=1.2)
    sn.heatmap(df_cm, annot=True, annot_kws={"size": 16}, fmt='d')
    buf = io.BytesIO()

    plt.savefig(buf, format='jpeg')
    plt.close(fig)  # avoid keeping one open figure per epoch
    buf.seek(0)
    im = Image.open(buf)
    im = torchvision.transforms.ToTensor()(im)
    tb.add_image("val_confusion_matrix", im, global_step=self.current_epoch)
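The code above renders the heatmap to an in-memory JPEG and converts it with torchvision's ToTensor. A sketch of the same conversion using only PIL and NumPy, assuming PNG is acceptable (it is lossless, so the cell annotations stay sharp) and that the figure can be closed right after rendering (figure_to_chw is a hypothetical helper name):

```python
import io

import matplotlib

matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image


def figure_to_chw(fig):
    """Rasterise a Matplotlib figure to a float32 array of shape (3, H, W) in [0, 1]."""
    buf = io.BytesIO()
    # PNG is lossless, so the annotated numbers are not blurred by compression.
    fig.savefig(buf, format="png", bbox_inches="tight")
    plt.close(fig)  # free the figure; otherwise one figure leaks per epoch
    buf.seek(0)
    with Image.open(buf) as img:
        # Drop the alpha channel and scale 0..255 to 0..1, as ToTensor does.
        rgb = np.asarray(img.convert("RGB"), dtype=np.float32) / 255.0  # (H, W, 3)
    return rgb.transpose(2, 0, 1)  # channels first (CHW)
```

The result can be passed as tb.add_image("val_confusion_matrix", figure_to_chw(fig), global_step=self.current_epoch), since add_image accepts NumPy arrays and expects the CHW layout by default.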

And the call to the Trainer:

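from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

logger = TensorBoardLogger(save_dir=tb_logs_folder, name='Classifier')
trainer = Trainer(
    deterministic=True,
    max_epochs=10,
    default_root_dir=classifier_checkpoints_path,
    logger=logger,
    gpus=1  # Lightning 2.x replaces this with accelerator="gpu", devices=1
)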

4 votes

Updated answer, August 2022


import io

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sn
import torch
import torchmetrics
import torchvision
from matplotlib.text import Text
from PIL import Image
from pytorch_lightning import LightningModule


class IntHandler:
    """Legend handler that draws the integer handle (the class index) as the legend key."""
    def legend_artist(self, legend, orig_handle, fontsize, handlebox):
        x0, y0 = handlebox.xdescent, handlebox.ydescent
        text = Text(x0, y0, str(orig_handle))
        handlebox.add_artist(text)
        return text



class LightningClassifier(LightningModule):
    ...

    def _common_step(self, batch, batch_nb, stage: str):
        assert stage in ("train", "val", "test")
        augmented_image, labels = batch

        outputs, aux_outputs = self(augmented_image)
        loss = self._criterion(outputs, labels)

        return outputs, labels, loss

    def validation_step(self, batch, batch_nb):
        stage = "val"
        outputs, labels, loss = self._common_step(batch, batch_nb, stage=stage)
        self._common_log(loss, stage=stage)

        return {"loss": loss, "outputs": outputs, "labels": labels}


    def validation_epoch_end(self, outs):
        # see https://github.com/Lightning-AI/metrics/blob/ff61c482e5157b43e647565fa0020a4ead6e9d61/docs/source/pages/lightning.rst
        # each forward pass, thus leading to wrong accumulation. In practice do the following:
        tb = self.logger.experiment  # noqa

        outputs = torch.cat([tmp['outputs'] for tmp in outs])
        labels = torch.cat([tmp['labels'] for tmp in outs])

        # torchmetrics >= 0.11 requires task=...; outputs.device also works on CPU,
        # unlike outputs.get_device(), which returns -1 there
        confusion = torchmetrics.ConfusionMatrix(task="multiclass", num_classes=self.n_labels).to(outputs.device)
        confusion(outputs, labels)
        computed_confusion = confusion.compute().detach().cpu().numpy().astype(int)

        # confusion matrix
        df_cm = pd.DataFrame(
            computed_confusion,
            index=list(self._label_ind_by_names.values()),
            columns=list(self._label_ind_by_names.values()),
        )

        fig, ax = plt.subplots(figsize=(10, 5))
        fig.subplots_adjust(left=0.05, right=.65)
        sn.set(font_scale=1.2)
        sn.heatmap(df_cm, annot=True, annot_kws={"size": 16}, fmt='d', ax=ax)
        # Legend maps each integer index on the axes to its class name
        ax.legend(
            list(self._label_ind_by_names.values()),
            list(self._label_ind_by_names.keys()),
            handler_map={int: IntHandler()},
            loc='upper left',
            bbox_to_anchor=(1.2, 1)
        )
        buf = io.BytesIO()

        fig.savefig(buf, format='jpeg', bbox_inches='tight')
        plt.close(fig)  # avoid keeping one open figure per epoch
        buf.seek(0)
        im = Image.open(buf)
        im = torchvision.transforms.ToTensor()(im)
        tb.add_image("val_confusion_matrix", im, global_step=self.current_epoch)

Output: [image: confusion-matrix heatmap with a class-name legend]

Also based on this.
