Is there an efficient way to load 2D slices from a 3D dataset in h5py?


I am a graduate student researching MR-to-CT translation.

I created MR, CT, and MASK groups in an h5py file and saved each patient's 3D data as a dataset inside each group. To load the data back, I implemented code that accesses the dataset in each group and retrieves individual slices (2D data); the code is shared below. The reason I did not stack everything along the slice dimension up front is that width and height differ between patients, so the volumes cannot be stacked into one array.
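For reference, a minimal sketch of the storage layout described above (the file name and the per-patient shapes are made up for illustration):

import h5py
import numpy as np

with h5py.File('synthRAD.h5', 'w') as f:
    for grp in ('MR', 'CT', 'MASK'):
        g = f.create_group(grp)
        # one 3D dataset per patient; width and height differ across patients
        g.create_dataset('patient_001', data=np.zeros((240, 220, 120)))
        g.create_dataset('patient_002', data=np.zeros((256, 232, 98)))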

However, this approach not only consumes a lot of CPU, it is also slow to load, so network training takes a very long time.

Is there a way to improve this code, or to improve how the data is stored and loaded? Thank you for taking the time to read this.

# Imports inferred from the code below (the dictionary transforms appear to be MONAI's)
import os

import h5py
import numpy as np
import torch
from torch.utils.data import Dataset
from monai.transforms import Compose, RandFlipd, RandRotate90d
from monai.utils import convert_to_tensor

# translate_images, motion_artifact, random_crop and random_crop2 are the
# author's own helpers and are not shown in the question

class dataset_synthRAD_FLY(Dataset):
    def __init__(
        self,
        data_dir: str,
        rand_crop: bool = False,
        misalign_x: float = 0.0,
        misalign_y: float = 0.0,
        degree: float = 0.0,
        motion_prob: float = 0.0,
        deform_prob: float = 0.0,
        aug: bool = False,
        reverse: bool = False,
        return_msk: bool = False,
        crop_size=256,
    ):
        super().__init__()
        self.rand_crop = rand_crop
        self.data_dir = data_dir
        
        # Each patient has a different number of slices        
        self.patient_keys = []
        with h5py.File(self.data_dir, 'r') as file:
            self.patient_keys = list(file['MR'].keys())
            self.slice_counts = [file['MR'][key].shape[-1] for key in self.patient_keys]
            self.cumulative_slice_counts = np.cumsum([0] + self.slice_counts)
        
        if return_msk:
            self.aug_func = Compose(
                [
                    RandFlipd(keys=["A", "B", "M"], prob=0.5, spatial_axis=[0, 1]),
                    RandRotate90d(keys=["A", "B", "M"], prob=0.5, spatial_axes=[0, 1]),
                ]
            )

        else:
            self.aug_func = Compose(
                [
                    RandFlipd(keys=["A", "B"], prob=0.5, spatial_axis=[0, 1]),
                    RandRotate90d(keys=["A", "B"], prob=0.5, spatial_axes=[0, 1]),
                ]
            )

        self.misalign_x = misalign_x
        self.misalign_y = misalign_y
        self.degree = degree
        self.motion_prob = motion_prob
        self.deform_prob = deform_prob
        self.aug = aug
        self.reverse = reverse
        self.return_msk = return_msk
        self.crop_size = crop_size

        # HDF5_USE_FILE_LOCKING is read when the file is opened, so set it
        # before opening the long-lived read handle
        os.environ["HDF5_USE_FILE_LOCKING"] = "TRUE"
        self.h5file = h5py.File(self.data_dir, 'r')

    def __len__(self):
        # (re-setting HDF5_USE_FILE_LOCKING here has no effect once the file is open)
        return self.cumulative_slice_counts[-1]

    def __del__(self): 
        self.h5file.close()
    
    def __getitem__(self, idx):
        patient_idx = np.searchsorted(self.cumulative_slice_counts, idx+1) - 1
        slice_idx = idx - self.cumulative_slice_counts[patient_idx]
        patient_key = self.patient_keys[patient_idx]
        
        A = self.h5file["MR"][patient_key][..., slice_idx]
        if (
            self.deform_prob > 0
            and idx > 2
            and idx < self.__len__() - 2
            and np.random.rand() < self.deform_prob
        ):
            # np.random.randint's upper bound is exclusive, so the original
            # randint(slice_idx + 1, slice_idx + 2) always returned slice_idx + 1;
            # clamp so the index cannot run past the end of this patient's volume
            slice_idx_new = min(slice_idx + 1, self.slice_counts[patient_idx] - 1)
            B = self.h5file["CT"][patient_key][..., slice_idx_new]
        else:
            B = self.h5file["CT"][patient_key][..., slice_idx]

        if self.return_msk:
            M = self.h5file["MASK"][patient_key][..., slice_idx]
            M = torch.from_numpy(M[None])

        A = A.astype(np.float32)
        B = B.astype(np.float32)

        A = torch.from_numpy(A[None])
        B = torch.from_numpy(B[None])

        # Create a dictionary for the data
        if self.return_msk:
            data_dict = {"A": A, "B": B, "M": M}
        else:
            data_dict = {"A": A, "B": B}

        # Apply the random flipping
        if self.aug:
            data_dict = self.aug_func(data_dict)

        A = data_dict["A"]
        A = convert_to_tensor(A)
        B = data_dict["B"]
        B = convert_to_tensor(B)

        if self.return_msk:
            M = data_dict["M"]
            M = convert_to_tensor(M)

        # Perform misalignment (Rigid)
        # First: translation
        if self.misalign_x == 0 and self.misalign_y == 0 and self.degree == 0:
            pass
        else:
            A, B = translate_images(A, B, self.misalign_x, self.misalign_y, self.degree)

        if np.random.rand() < self.motion_prob:
            if self.reverse:
                A = motion_artifact(A)  # A is the label
            else:
                B = motion_artifact(B)  # B is the label

        # make sure values stay in [-1, 1]
        A = torch.clamp(A, min=-1, max=1)
        B = torch.clamp(B, min=-1, max=1)
        
        if self.rand_crop:
            if self.return_msk:
                A, B, M = random_crop2(A, B, M, (self.crop_size, self.crop_size))

            else:
                A, B = random_crop(A, B, (self.crop_size, self.crop_size))

        if self.reverse:
            if self.return_msk:
                return B, A, M
            else:
                return B, A
        else:
            if self.return_msk:
                return A, B, M
            else:
                return A, B
  • Do as much of the work as possible in the dataset's __init__ (a preloading sketch follows this list).
  • Do not use gzip compression when saving the .h5 file; uncompressed datasets load faster.
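A minimal sketch of what the first comment suggests, reusing the imports from the question's code: read each patient's volume into memory once in __init__ and slice from RAM in __getitem__. The class name is made up, and it assumes the whole dataset fits in memory.

class dataset_synthRAD_preloaded(Dataset):
    def __init__(self, data_dir: str):
        super().__init__()
        # One h5py read per volume, done once up front
        self.volumes = []
        with h5py.File(data_dir, 'r') as f:
            for key in f['MR'].keys():
                self.volumes.append((f['MR'][key][()], f['CT'][key][()]))
        self.slice_counts = [mr.shape[-1] for mr, _ in self.volumes]
        self.cumulative = np.cumsum([0] + self.slice_counts)

    def __len__(self):
        return self.cumulative[-1]

    def __getitem__(self, idx):
        # Map the flat index to (patient, slice), as in the question
        p = np.searchsorted(self.cumulative, idx + 1) - 1
        s = idx - self.cumulative[p]
        mr, ct = self.volumes[p]
        A = torch.from_numpy(mr[..., s].astype(np.float32)[None])
        B = torch.from_numpy(ct[..., s].astype(np.float32)[None])
        return A, B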
1 Answer

The code below demonstrates the difference in I/O performance between 3 ways of reading HDF5 dataset slices into NumPy arrays. The methods are:

  1. Read image array slices from an unchunked dataset (the method in the OP)
  2. Read the entire unchunked dataset into an array, then slice the array (the method from my comment)
  3. Read image array slices from a chunked dataset

Method 3 uses the same read calls as method 1, but the dataset must be created with chunked storage. The chunk shape matches 1 image slice:

(img_x, img_y, 1)
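If the file already exists with contiguous (unchunked) datasets, it can be rewritten with per-slice chunks. A minimal sketch, with placeholder file names ('old.h5', 'new.h5'):

import h5py

# Copy every dataset into a new file, adding per-slice chunking
with h5py.File('old.h5', 'r') as src, h5py.File('new.h5', 'w') as dst:
    for grp in src:                      # e.g. 'MR', 'CT', 'MASK'
        new_grp = dst.create_group(grp)
        for key, ds in src[grp].items():
            img_x, img_y, _ = ds.shape
            new_grp.create_dataset(key, data=ds[()], chunks=(img_x, img_y, 1))
            print(new_grp[key].chunks)   # verify: (img_x, img_y, 1)

The HDF5 h5repack command-line utility can perform the same layout change without custom code.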

The timing results show that methods 2 and 3 are both significantly faster than method 1: method 2 is about 20x faster and method 3 about 11x faster.

Timing data

  1. Reading array slices from the UNCHUNKED dataset: 16.53 s
  2. Reading the UNCHUNKED dataset into an array, then slicing: 0.81 s
  3. Reading array slices from the CHUNKED dataset: 1.52 s

Code

The code below builds a sample HDF5 file and then reads it with the 3 methods above. It creates a 2.4 GB file; adjust the variables at the top if you want a larger (or smaller) test file.

import time

import h5py
import numpy as np

n_patients = 100
img_x, img_y = 256, 256
n_slices = 25

# Create file with 2 groups: MR1 unchunked datasets and MR2 chunked datasets
start = time.time()
with h5py.File('SO_77378615.h5','w') as h5f: 
    mr_grp1 = h5f.create_group('MR1')
    mr_grp2 = h5f.create_group('MR2')
    for i in range(1,n_patients+1):
        patient = f'patient_{i:03}'
        img_arr = np.random.random(img_x*img_y*n_slices).reshape(img_x,img_y,n_slices)
        mr_grp1.create_dataset(patient,data=img_arr)
        mr_grp2.create_dataset(patient,data=img_arr,chunks=(img_x,img_y,1))
print(f'time to create file with 2 datasets: {time.time()-start:.03f}')
                
# Open file and read all patient data with each method
with h5py.File('SO_77378615.h5') as h5f:
    mr_grp1 = h5f['MR1']  # group object for the unchunked datasets

    # Method 1: read each slice straight from the unchunked dataset
    start = time.time()
    for patient in mr_grp1:
        for slc in range(mr_grp1[patient].shape[2]):
            arr = mr_grp1[patient][:, :, slc]
    print(f'time to read array slices from UNCHUNKED dataset: {time.time()-start:.03f}')

    # Method 2: read the whole dataset into memory, then slice the array
    start = time.time()
    for patient in mr_grp1:
        pat_arr = mr_grp1[patient][()]
        for slc in range(pat_arr.shape[2]):
            arr = pat_arr[:, :, slc]
    print(f'time to read UNCHUNKED dataset to array, then slice: {time.time()-start:.03f}')

    # Method 3: read each slice from the chunked dataset
    mr_grp2 = h5f['MR2']  # group object for the chunked datasets
    start = time.time()
    for patient in mr_grp2:
        for slc in range(mr_grp2[patient].shape[2]):
            arr = mr_grp2[patient][:, :, slc]
    print(f'time to read array slices from CHUNKED dataset: {time.time()-start:.03f}')