我有一个自定义数据集,其中数据已存储为表单中的字典,
文件1.pt {张量1:张量2} 文件2.pt {张量1:张量2}。依此类推,大约 50k 个文件聚集 20GB 卷。
Tensor1 是数据,Tensor2 是其标签。保留张量的最佳方法是什么,或者最好作为张量加载到数据加载器中,而不是所有文件中的 dict_keys 或 dict_values 类型。
我目前已将所有词典加载到数据集中。
使用“dict.keys()”和“dict.values()”,需要转换为列表然后进一步处理。我正在寻找更快的东西。
实现覆盖 getitem 和 len 方法的自定义数据集类。通过这样做,您可以确保数据是延迟加载的,而不是一次性加载的,这应该更节省内存。
class CustomTensorDataset(Dataset):
def __init__(self, file_paths):
self.file_paths = file_paths
def __len__(self):
return len(self.file_paths)
def __getitem__(self, idx):
file_path = self.file_paths[idx]
data_dict = torch.load(file_path)
return data_dict["tensor1"], data_dict["tensor2"]
# Collect file paths
folder_path = "./your_data_folder"
file_paths = [os.path.join(folder_path, fname) for fname in os.listdir(folder_path) if fname.endswith('.pt')]
# Initialize custom dataset and DataLoader
custom_dataset = CustomTensorDataset(file_paths)
data_loader = DataLoader(custom_dataset, batch_size=32, shuffle=True)