我是一名研究生,研究 MR 到 CT 的翻译。
我用 h5py 在 HDF5 文件中创建了 MR、CT 和 MASK 三个组,并在每个组中以数据集的形式保存了每位患者的 3D 数据。再次加载数据时,我编写了代码来访问每个组中的数据集并逐个读取切片(二维数据),代码见下。之所以没有把所有切片沿切片维度一次性堆叠成一个数据集,是因为每位患者图像的宽度和高度都不同,无法直接堆叠。
但是,这种方式不仅占用大量 CPU,而且加载也很耗时,导致网络训练需要很长时间。
有没有办法改进这段代码或改进数据存储和加载的方式? 感谢您花时间阅读。
class dataset_synthRAD_FLY(Dataset):
    """Slice-level MR/CT (optionally MASK) dataset read from a single HDF5 file.

    The file must contain the groups ``"MR"`` and ``"CT"`` (and ``"MASK"`` when
    ``return_msk`` is set). Each group holds one dataset per patient, and the
    groups share the same patient keys. The slice axis is the LAST axis of every
    dataset, so one item of this dataset is one 2-D slice of one patient.
    Patients may differ in in-plane size and slice count. A global index is
    mapped to ``(patient, slice)`` through cumulative slice counts.

    Per-slice reads are cheap only when the datasets are chunked along the
    slice axis (e.g. ``chunks=(H, W, 1)``). With contiguous storage, the
    elements of one last-axis slice are scattered through the file.

    Args:
        data_dir: Path to the HDF5 file.
        rand_crop: If True, randomly crop the outputs to ``crop_size`` x ``crop_size``.
        misalign_x, misalign_y, degree: Forwarded to ``translate_images`` to
            misalign A and B; ``degree`` is presumably a rotation angle — confirm.
        motion_prob: Probability of applying ``motion_artifact`` to the target
            image (B, or A when ``reverse`` is set).
        deform_prob: Probability of pairing the MR slice with the NEXT CT slice
            of the same patient (a one-slice offset between A and B).
        aug: If True, apply joint random flips / 90-degree rotations (p=0.5 each).
        reverse: If True, return ``(B, A[, M])`` instead of ``(A, B[, M])``.
        return_msk: If True, also load and return the ``"MASK"`` slice.
        crop_size: Side length of the random crop.

    Returns (from ``__getitem__``):
        ``(A, B)`` or ``(A, B, M)`` (swapped to ``(B, A, ...)`` if ``reverse``).
        A is the MR slice and B is the CT slice. Each is read as a 2-D slice and
        given a leading channel axis, presumably ``(1, H, W)`` before cropping —
        confirm for non-3-D datasets.
    """

    def __init__(
        self,
        data_dir: str,
        rand_crop: bool = False,
        misalign_x: float = 0.0,
        misalign_y: float = 0.0,
        degree: float = 0.0,
        motion_prob: float = 0.0,
        deform_prob: float = 0.0,
        aug: bool = False,
        reverse: bool = False,
        return_msk: bool = False,
        crop_size=256,
    ):
        super().__init__()
        # HDF5 consults this variable when a file is opened, so set it before
        # the first h5py.File call below; it has no effect on already-open files.
        os.environ["HDF5_USE_FILE_LOCKING"] = "TRUE"

        self.data_dir = data_dir
        self.rand_crop = rand_crop
        self.misalign_x = misalign_x
        self.misalign_y = misalign_y
        self.degree = degree
        self.motion_prob = motion_prob
        self.deform_prob = deform_prob
        self.aug = aug
        self.reverse = reverse
        self.return_msk = return_msk
        self.crop_size = crop_size

        # Only the index metadata is read here, with a short-lived handle.
        with h5py.File(self.data_dir, "r") as file:
            self.patient_keys = list(file["MR"].keys())
            # Each patient has a different number of slices (last axis).
            self.slice_counts = [file["MR"][key].shape[-1] for key in self.patient_keys]
        # cumulative_slice_counts[p] = global index of patient p's first slice.
        self.cumulative_slice_counts = np.cumsum([0] + self.slice_counts)

        aug_keys = ["A", "B", "M"] if return_msk else ["A", "B"]
        self.aug_func = Compose(
            [
                RandFlipd(keys=aug_keys, prob=0.5, spatial_axis=[0, 1]),
                RandRotate90d(keys=aug_keys, prob=0.5, spatial_axes=[0, 1]),
            ]
        )

        # The data handle is opened lazily and per process (see ``h5file``).
        # An HDF5 handle opened here would be inherited by forked DataLoader
        # workers, and sharing one handle across fork is unsafe.
        self._h5file = None
        self._h5_pid = None

    @property
    def h5file(self):
        """Return a read-only handle to ``data_dir`` owned by the current process.

        The file is opened on first use, and reopened when the caller runs in a
        different process than the one that opened it (e.g. a forked worker).
        """
        pid = os.getpid()
        if self._h5file is None or self._h5_pid != pid:
            self._h5file = h5py.File(self.data_dir, "r")
            self._h5_pid = pid
        return self._h5file

    def __getstate__(self):
        # Never pickle the live HDF5 handle (e.g. for spawn-based workers);
        # the receiving process reopens it lazily through ``h5file``.
        state = self.__dict__.copy()
        state["_h5file"] = None
        state["_h5_pid"] = None
        return state

    def __len__(self):
        return int(self.cumulative_slice_counts[-1])

    def __del__(self):
        # getattr: __init__ may have raised before the attribute was created.
        handle = getattr(self, "_h5file", None)
        if handle is not None:
            handle.close()
            self._h5file = None

    def __getitem__(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(f"index {idx} out of range for dataset of length {len(self)}")

        # Map the global slice index to (patient, slice within that patient).
        patient_idx = int(np.searchsorted(self.cumulative_slice_counts, idx + 1)) - 1
        slice_idx = int(idx - self.cumulative_slice_counts[patient_idx])
        patient_key = self.patient_keys[patient_idx]
        h5 = self.h5file

        A = h5["MR"][patient_key][..., slice_idx]

        # With probability deform_prob, pair the MR slice with the next CT slice
        # of the same patient. The bound is checked per patient, so the last
        # slice of a patient never reads past the end of its CT volume.
        ct_slice_idx = slice_idx
        if (
            self.deform_prob > 0
            and slice_idx + 1 < self.slice_counts[patient_idx]
            and np.random.rand() < self.deform_prob
        ):
            ct_slice_idx = slice_idx + 1
        B = h5["CT"][patient_key][..., ct_slice_idx]

        # Add a leading channel axis; images become float32, the mask keeps its dtype.
        data_dict = {
            "A": torch.from_numpy(A.astype(np.float32)[None]),
            "B": torch.from_numpy(B.astype(np.float32)[None]),
        }
        if self.return_msk:
            M = h5["MASK"][patient_key][..., slice_idx]
            data_dict["M"] = torch.from_numpy(M[None])

        # Random flips / 90-degree rotations, applied jointly to every key.
        if self.aug:
            data_dict = self.aug_func(data_dict)
        A = convert_to_tensor(data_dict["A"])
        B = convert_to_tensor(data_dict["B"])
        M = convert_to_tensor(data_dict["M"]) if self.return_msk else None

        # Rigid misalignment between A and B (skipped when all offsets are 0).
        if self.misalign_x != 0 or self.misalign_y != 0 or self.degree != 0:
            A, B = translate_images(A, B, self.misalign_x, self.misalign_y, self.degree)

        # The motion artifact is applied to whichever image is returned as the label.
        if np.random.rand() < self.motion_prob:
            if self.reverse:
                A = motion_artifact(A)  # A is the label
            else:
                B = motion_artifact(B)  # B is the label

        # Make sure intensities stay in [-1, 1].
        A = torch.clamp(A, min=-1, max=1)
        B = torch.clamp(B, min=-1, max=1)

        if self.rand_crop:
            if self.return_msk:
                A, B, M = random_crop2(A, B, M, (self.crop_size, self.crop_size))
            else:
                A, B = random_crop(A, B, (self.crop_size, self.crop_size))

        source, target = (B, A) if self.reverse else (A, B)
        if self.return_msk:
            return source, target, M
        return source, target
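下面的代码演示了将 HDF5 数据集切片读取到 NumPy 数组的 3 种方法在 I/O 性能方面的差异。方法有:
1. 方法 1:直接从未分块(连续存储)的数据集中逐个读取切片(与问题中的读取方式相同);
2. 方法 2:先将整个未分块数据集一次性读入 NumPy 数组,再在内存中切片;
3. 方法 3:从按切片分块存储的数据集中逐个读取切片。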
方法 3 与方法 1 的读取过程相同,但数据集必须以分块存储方式创建,且块形状与 1 个图像切片相匹配,即 `(img_x, img_y, 1)`。如果各患者图像的宽高不同,按每个体数据自身的形状设置即可,例如 `chunks=(h, w, 1)`。
计时结果显示,方法 2 和方法 3 均明显快于方法 1:方法 2 约快 20 倍,方法 3 约快 11 倍。
计时数据
代码
下面的代码先构建示例 HDF5 文件,然后用上述 3 种方法读取。它会创建一个约 2.4 GB 的文件;如果需要更大(或更小)的测试文件,可以调整顶部的变量。
# Benchmark: three ways of reading 2-D slices (last axis) from per-patient
# 3-D HDF5 datasets into NumPy arrays.
import time

import h5py
import numpy as np

n_patients = 100
img_x, img_y = 256, 256
n_slices = 25
h5_path = "SO_77378615.h5"

# Create file with 2 groups: MR1 holds unchunked (contiguous) datasets, MR2
# holds datasets chunked as one 2-D slice per chunk. The chunk shape is taken
# from each volume, so this also works when patients differ in width/height.
start = time.perf_counter()
with h5py.File(h5_path, "w") as h5f:
    mr_grp1 = h5f.create_group("MR1")
    mr_grp2 = h5f.create_group("MR2")
    for i in range(1, n_patients + 1):
        patient = f"patient_{i:03}"
        img_arr = np.random.random((img_x, img_y, n_slices))
        mr_grp1.create_dataset(patient, data=img_arr)
        mr_grp2.create_dataset(patient, data=img_arr, chunks=(*img_arr.shape[:2], 1))
elapsed = time.perf_counter() - start
print(f"time to create file with 2 datasets: {elapsed:.03f}")

# Open file and read all patient data. The dataset object is looked up once
# per patient so that only the slice reads themselves are timed.
with h5py.File(h5_path, "r") as h5f:
    # Method 1: read each slice directly from an unchunked dataset.
    mr_grp1 = h5f["MR1"]
    start = time.perf_counter()
    for patient in mr_grp1:
        dset = mr_grp1[patient]
        for k in range(dset.shape[2]):
            arr = dset[:, :, k]
    elapsed = time.perf_counter() - start
    print(f"time to read array slices from UNCHUNKED dataset: {elapsed:.03f}")

    # Method 2: read the whole unchunked dataset once, then slice in memory.
    start = time.perf_counter()
    for patient in mr_grp1:
        pat_arr = mr_grp1[patient][()]
        for k in range(pat_arr.shape[2]):
            arr = pat_arr[:, :, k]
    elapsed = time.perf_counter() - start
    print(f"time to read UNCHUNKED dataset to array, then slice: {elapsed:.03f}")

    # Method 3: read each slice directly from a slice-chunked dataset.
    mr_grp2 = h5f["MR2"]
    start = time.perf_counter()
    for patient in mr_grp2:
        dset = mr_grp2[patient]
        for k in range(dset.shape[2]):
            arr = dset[:, :, k]
    elapsed = time.perf_counter() - start
    print(f"time to read array slices from CHUNKED dataset: {elapsed:.03f}")