Pytorch 指标:多标签混淆矩阵不同设备错误

问题描述 投票:0回答:1

我正在尝试使用 torchmetrics 计算多标签输出的混淆矩阵,但出现以下错误:

File "/home/antpc/.local/lib/python3.8/site-packages/torchmetrics/metric.py", line 394, in wrapped_func
    raise RuntimeError(
RuntimeError: Encountered different devices in metric calculation (see stacktrace for details).This could be due to the metric class not being on the same device as input.Instead of `metric=ConfusionMatrix(...)` try to do `metric=ConfusionMatrix(...).to(device)` where device corresponds to the device of the input.

我的代码:

from torchmetrics import ConfusionMatrix
def calculate_metrics(predictions, targets):
    cm = ConfusionMatrix(num_classes=34, multilabel=True)
    matrix = cm(predictions, targets)
    return matrix

然后我尝试将代码更改为:

from torchmetrics import ConfusionMatrix
def calculate_metrics(predictions, targets):
    cm = ConfusionMatrix(num_classes=34, multilabel=True).to(device='cpu')
    matrix = cm(predictions.detach().cpu(), targets.detach().cpu())
    return matrix

仍然显示相同的错误。谁能帮我解决这个问题吗?

请不要建议我使用

sklearn.metrics.multilabel_confusion_matrix

pytorch metrics confusion-matrix multilabel-classification pytorch-lightning
1个回答
0
投票

此错误不是由指标引起的,而是由于使用多个 GPU 而由 Pytorch 闪电引起的。

我以前的代码:

model = ModelClassifier()
trainer = pl.Trainer(strategy='dp', max_epochs=150, gpus=8, fast_dev_run=True)
trainer.fit(model, train_loader)

更改

strategy
并删除
fast_dev_run=True

后错误已解决

工作代码:

model = ModelClassifier()
trainer = pl.Trainer(strategy='ddp', max_epochs=150, gpus=8)
trainer.fit(model, train_loader)
© www.soinside.com 2019 - 2024. All rights reserved.