如何使用带有3-D矩阵的pytorch DataLoader进行LSTM输入?

问题描述 投票:0回答:1

我有一个3-D(time_stepinputsizetotal_num)矩阵的数据集,它是一个.mat文件。我想使用DataLoader获取LSTM的输入数据集,其中batch_size为5.我的代码如下:

file_path = "…/database/frameLength100/notOverlap/a.mat"
mat_data = s.loadmat(file_path)
tensor_data = torch.from_numpy(mat_data[‘a’]) #Tensor

class CustomDataset(Dataset):

def __init__(self, tensor_data):
    self.tensor_data = tensor_data

def __getitem__(self, index):
    data = self.tensor_data[index]
    label = 1;
    return data, label

def __len__(self):
    return len(self.tensor_data)
custom_dataset = CustomDataset(tensor_data=tensor_data)
train_loader = DataLoader(dataset=custom_dataset, batch_size=5, shuffle=True)

我认为代码是错误的,但我不知道如何纠正它。令我困惑的是如何让DataLoader知道哪个维度是'total_num',这样我就可以获得批量大小为5的数据集。

pytorch
1个回答
0
投票

如果我理解正确,你希望批量发生在total_num维度上,i。即维度2。

您可以简单地使用维度来索引数据集,即将__getitem__更改为data = self.tensor_data[:, :, index],并相应地在__len__中,返回self.tensor_data.size(2)而不是len(self.tensor_data)。然后每批都有[time_step, inputsize, 5]的大小。

© www.soinside.com 2019 - 2024. All rights reserved.