我需要用我训练的卷积神经网络的数据测试结果编写一个文件。数据包括语音数据采集。文件格式需要是“文件名,预测”,但我很难提取文件名。我这样加载数据:
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
TEST_DATA_PATH = ...
trans = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
test_dataset = torchvision.datasets.MNIST(
root=TEST_DATA_PATH,
train=False,
transform=trans,
download=True
)
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)
我正在尝试按如下方式写入文件:
f = open("test_y", "w")
with torch.no_grad():
for i, (images, labels) in enumerate(test_loader, 0):
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
file = os.listdir(TEST_DATA_PATH + "/all")[i]
format = file + ", " + str(predicted.item()) + '\n'
f.write(format)
f.close()
os.listdir(TESTH_DATA_PATH + "/all")[i]
的问题是它与test_loader
的加载文件顺序不同步。我能做什么?
嗯,这取决于你的
Dataset
是如何实现的。例如,在 torchvision.datasets.MNIST(...)
的情况下,您无法检索文件名,因为不存在单个样本的文件名(MNIST 样本以不同的方式加载)。
由于您没有展示您的
Dataset
实现,我将告诉您如何使用 torchvision.datasets.ImageFolder(...)
(或任何 torchvision.datasets.DatasetFolder(...)
)来完成此操作:
f = open("test_y", "w")
with torch.no_grad():
for i, (images, labels) in enumerate(test_loader, 0):
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
sample_fname, _ = test_loader.dataset.samples[i]
f.write("{}, {}\n".format(sample_fname, predicted.item()))
f.close()
__getitem__(self, index)
期间检索了文件的路径,特别是这里。
如果您实现了自己的
Dataset
(并且也许想支持 shuffle
和 batch_size > 1
),那么我会在 sample_fname
调用中返回 __getitem__(...)
并执行以下操作:
for i, (images, labels, sample_fname) in enumerate(test_loader, 0):
# [...]
这样你就不需要关心
shuffle
。如果 batch_size
大于 1,则需要更改循环的内容以获得更通用的内容,例如:
f = open("test_y", "w")
for i, (images, labels, samples_fname) in enumerate(test_loader, 0):
outputs = model(images)
pred = torch.max(outputs, 1)[1]
f.write("\n".join([
", ".join(x)
for x in zip(map(str, pred.cpu().tolist()), samples_fname)
]) + "\n")
f.close()
DataLoader
可以为您提供其内部数据集中的批次。
AS @Barriel 在单/多标签分类问题中提到,
DataLoader
没有图像文件名,只有代表图像的张量和类/标签。
但是,
DataLoader
加载对象时的构造函数可以接受一些小东西(如果您愿意,您可以与数据集一起打包目标/标签和文件名),甚至是数据框
这样,
DataLoader
可能会以某种方式抓住你需要的东西。
我在调试最近训练的网络时遇到了同样的问题。为了从原始问题中定义的名为
dataloader
的 DataLoader 实例获取文件名,我使用了: dataloader.dataset.imgs[0][0]