拆开字典

问题描述 投票:0回答:1

我有一个自定义数据集,其中数据已存储为表单中的字典,

文件1.pt {张量1:张量2} 文件2.pt {张量1:张量2}。依此类推,大约 50k 个文件聚集 20GB 卷。

Tensor1 是数据,Tensor2 是其标签。保留张量的最佳方法是什么,或者最好作为张量加载到数据加载器中,而不是所有文件中的 dict_keys 或 dict_values 类型。

我目前已将所有词典加载到数据集中。

使用“dict.keys()”和“dict.values()”,需要转换为列表然后进一步处理。我正在寻找更快的东西。

python dictionary pytorch unpack
1个回答
0
投票

实现覆盖 getitemlen 方法的自定义数据集类。通过这样做,您可以确保数据是延迟加载的,而不是一次性加载的,这应该更节省内存。

class CustomTensorDataset(Dataset):
def __init__(self, file_paths):
    self.file_paths = file_paths

def __len__(self):
    return len(self.file_paths)

def __getitem__(self, idx):
    file_path = self.file_paths[idx]
    data_dict = torch.load(file_path)
    return data_dict["tensor1"], data_dict["tensor2"]

# Collect file paths
folder_path = "./your_data_folder"
file_paths = [os.path.join(folder_path, fname) for fname in os.listdir(folder_path) if fname.endswith('.pt')]

# Initialize custom dataset and DataLoader
custom_dataset = CustomTensorDataset(file_paths)
data_loader = DataLoader(custom_dataset, batch_size=32, shuffle=True)
© www.soinside.com 2019 - 2024. All rights reserved.