我有一个Pytorch dataloading的数据类。它从hdf5档案中检索项目(15万个样本),然后我将其输入到dataloader中,并训练一个小型的一个隐藏层自动编码器。然而,当我尝试训练我的网络时,什么都没有发生,没有GPU利用率。我正在使用,4个CPU和2个GPU开始。
我的批次大小是128,我开始训练时使用8个工人。
我也遵循了Pytorchs的数据并行教程。下面是我的hdf5数据类的代码。
import torch.multiprocessing as mp
mp.set_start_method('fork')
from torch.utils import data
import h5py
import time
class Features_Dataset(data.Dataset):
def __init__(self, file_path, phase):
self.file_path = file_path
self.archive = None
self.phase = phase
with h5py.File(file_path, 'r', libver='latest', swmr=True) as f:
self.length = len(f[(self.phase) + '_labels'])
def _get_archive(self):
if self.archive is None:
self.archive = h5py.File(self.file_path, 'r', libver='latest', swmr=True)
assert self.archive.swmr_mode
return self.archive
def __getitem__(self, index):
archive = self._get_archive()
label = archive[str(self.phase) + '_labels']
datum = archive[str(self.phase) + '_all_arrays']
path = archive[str(self.phase) + '_img_paths']
return datum[index], label[index], path[index]
def __len__(self):
return self.length
def close(self):
self.archive.close()
if __name__ == '__main__':
train_dataset = Features_Dataset(file_path= "featuresdata/train.hdf5", phase= 'train')
trainloader = data.DataLoader(train_dataset, num_workers=8, batch_size=1)
print(len(trainloader))
myStart = time.time()
for i, (data, label, path) in enumerate(trainloader):
print(path)
这是我的自动编码器类。
import torch
import torch.nn as nn
class AutoEncoder(nn.Module):
def __init__(self, n_embedded):
super(AutoEncoder, self).__init__()
self.encoder = nn.Sequential(
nn.Linear(6144, n_embedded))
self.decoder = nn.Sequential(nn.Linear(n_embedded, 6144))
def forward(self, x):
encoded = self.encoder(x)
decoded = self.decoder(encoded)
return encoded, decoded
这是我如何初始化模型
device = torch.device("cuda")
# Initialize / load checkpoint
model = AutoEncoder(2048)
if torch.cuda.device_count() > 1:
print("Let's use", torch.cuda.device_count(), "GPUs!")
model= nn.DataParallel(model)
model.to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(),weight_decay=1e-5)
我确保我的输入也被输入到设备中.
会不会是检索批次的速度问题?关于hdf5特征数据集类,我试图在不使用hdf5特征数据集的情况下懒惰地加载hdf5数据集。__init__
然而,我想也许计算数据集的长度可能是个问题......。
这个问题可能是懒惰加载造成的瓶颈。你可以尝试在 启动 的dataloader(如果你有足够的资源)。然后,在 获取项目 只要从已经准备好的列表中返回datum[index]、label[index]、path[index]即可。希望能帮到你。希望能帮助到你。