如何从DataLoader获取样本的文件名?

问题描述 投票:0回答:4

我需要用我训练的卷积神经网络的数据测试结果编写一个文件。数据包括语音数据采集。文件格式需要是“文件名,预测”,但我很难提取文件名。我这样加载数据:

import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

TEST_DATA_PATH = ...

trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

test_dataset = torchvision.datasets.MNIST(
    root=TEST_DATA_PATH,
    train=False,
    transform=trans,
    download=True
)

test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

我正在尝试按如下方式写入文件:

f = open("test_y", "w")
with torch.no_grad():
    for i, (images, labels) in enumerate(test_loader, 0):
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        file = os.listdir(TEST_DATA_PATH + "/all")[i]
        format = file + ", " + str(predicted.item()) + '\n'
        f.write(format)
f.close()

os.listdir(TESTH_DATA_PATH + "/all")[i]
的问题是它与
test_loader
的加载文件顺序不同步。我能做什么?

python machine-learning pytorch torchvision
4个回答
11
投票

嗯,这取决于你的

Dataset
是如何实现的。例如,在
torchvision.datasets.MNIST(...)
的情况下,您无法检索文件名,因为不存在单个样本的文件名(MNIST 样本以不同的方式加载)。

由于您没有展示您的

Dataset
实现,我将告诉您如何使用
torchvision.datasets.ImageFolder(...)
(或任何
torchvision.datasets.DatasetFolder(...)
)来完成此操作:

f = open("test_y", "w")
with torch.no_grad():
    for i, (images, labels) in enumerate(test_loader, 0):
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        sample_fname, _ = test_loader.dataset.samples[i]
        f.write("{}, {}\n".format(sample_fname, predicted.item()))
f.close()

您可以看到在

__getitem__(self, index)
期间检索了文件的路径,特别是这里

如果您实现了自己的

Dataset
(并且也许想支持
shuffle
batch_size > 1
),那么我会在
sample_fname
调用中返回
__getitem__(...)
并执行以下操作:

for i, (images, labels, sample_fname) in enumerate(test_loader, 0):
    # [...]

这样你就不需要关心

shuffle
。如果
batch_size
大于 1,则需要更改循环的内容以获得更通用的内容,例如:

f = open("test_y", "w")
for i, (images, labels, samples_fname) in enumerate(test_loader, 0):
    outputs = model(images)
    pred = torch.max(outputs, 1)[1]
    f.write("\n".join([
        ", ".join(x)
        for x in zip(map(str, pred.cpu().tolist()), samples_fname)
    ]) + "\n")
f.close()

1
投票

在一般情况下,

DataLoader
可以为您提供其内部数据集中的批次。

AS @Barriel 在单/多标签分类问题中提到,

DataLoader
没有图像文件名,只有代表图像的张量和类/标签。

但是,

DataLoader
加载对象时的构造函数可以接受一些小东西(如果您愿意,您可以与数据集一起打包目标/标签和文件名),甚至是数据框

这样,

DataLoader
可能会以某种方式抓住你需要的东西。


1
投票

如果您使用 PyCharm 或任何具有调试工具的 IDE,请使用它来查看您的 data_loader 内部,希望您可以看到文件名列表,就像我的情况一样。

就我而言, 我的 data_loader 是通过 mmegmentation 创建的。


0
投票

我在调试最近训练的网络时遇到了同样的问题。为了从原始问题中定义的名为

dataloader
的 DataLoader 实例获取文件名,我使用了:
dataloader.dataset.imgs[0][0]

© www.soinside.com 2019 - 2024. All rights reserved.