我定义了一个 GAN 模型,我想使用 FID 分数来评估它。我有 1 通道图像,它们是 mnist 数据集,但此方法需要 3 通道图像。我该如何解决这个问题?
在评估之前尝试将其分成 3 个通道。
import torch
import torchvision
from torcheval import metrics
# Load the MNIST dataset
mnist_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True)
# Convert the 1 channel images to 3 channel images
mnist_dataset.data = mnist_dataset.data.unsqueeze(1)
mnist_dataset.data = mnist_dataset.data.repeat(1, 3, 1, 1)
# Calculate the FID score
fid_score = metrics.FrechetInceptionDistance()(mnist_dataset.data)
# Evaluate the FID score
print('FID score:', fid_score)`enter code here`