如何使用数据加载器解决此问题?

问题描述 投票:0回答:1

我正在构建一些数据加载器来训练和测试机器学习模型。 我有一个名为“array”的元组列表,如下所示:

(Data(x=[468, 2], edge_index=[2, 1322], y=0, edge_weight=[1322]), 'morphed_img027485_img054553.png')
(Data(x=[468, 2], edge_index=[2, 1322], y=0, edge_weight=[1322]), 'morphed_img031737_img054553.png')

我像这样创建数据加载器:

data_loader = create_dataloader(array, batch_size=60)
save_dataloader(data_loader, 'NameofDataLoader')

输出不是我所期望的,但它将所有数据合并到一个 DataBatch 中,如下所示:

[DataBatch(x=[936, 2], edge_index=[2, 2644], y=[2], edge_weight=[2644], batch=[936], ptr=[3]), ('morphed_img031737_img054553.png', 'morphed_img027485_img054553.png')]

为什么?我怎样才能有一个数据加载器,将所有数据像数组一样分开?

python machine-learning torch dataloader pytorch-geometric
1个回答
-2
投票

您似乎正在使用 PyTorch 的 DataLoader 来处理数据。您观察到的所有数据合并到单个 DataBatch 中的行为是因为 DataLoader 根据您指定的 batch_size 参数将数据分组为批次。

如果您想维护列表中的各个元组而不是将它们分组为批次,则需要将batch_size参数设置为1或者根本不使用DataLoader。以下是修改代码的方法:

array = [
    (Data(x=[468, 2], edge_index=[2, 1322], y=0, edge_weight=[1322]), 'morphed_img027485_img054553.png'),
    (Data(x=[468, 2], edge_index=[2, 1322], y=0, edge_weight=[1322]), 'morphed_img031737_img054553.png')
]

# Creating a dataloader with batch size 1
data_loader = create_dataloader(array, batch_size=1)

# Saving each item separately
for idx, (data, filename) in enumerate(data_loader):
    save_dataloader([(data, filename)], f'NameofDataLoader_{idx}')

# Or simply save the array without using DataLoader
for idx, (data, filename) in enumerate(array):
    save_dataloader([(data, filename)], f'NameofDataLoader_{idx}')

这样,列表中的每个元组都将被单独处理,并且列表中的每个项目都将拥有单独的 DataLoader。

© www.soinside.com 2019 - 2024. All rights reserved.