无法对图像和掩模应用相同的变换以进行数据增强

问题描述 投票:0回答:1

我正在尝试使用 pytorch 训练 U-Net 模型构建。对于这种情况,我构建了数据集并在图像和掩模中应用了数据增强转换。情况是我想对两者应用相同的变换,这意味着,如果我将图像旋转一定的度数,我希望蒙版旋转相同的度数,这就是我的问题。图像和蒙版的旋转量不同。

我留下以下代码:

数据集

import torch
from torch.utils.data import Dataset
import os

class INBreastDataset2012(Dataset):
    def __init__(self, dict_dir, transform=None):
        self.dict_dir = dict_dir
        self.data = os.listdir(self.dict_dir)
        self.transform = transform



    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        dict_path = os.path.join(self.dict_dir, self.data[index])
        patient_dict = torch.load(dict_path)
        image = patient_dict['image'].unsqueeze(0)
        mass_mask = patient_dict['mass_mask'].unsqueeze(0)
        mass_mask[mass_mask > 1.0] = 1.0


        if self.transform is not None:
            image = self.transform(image)
            mass_mask = self.transform(mass_mask)
            
        
        return image, mass_mask


“训练”(此时并不是真正的训练,只是数据加载器带来的信息的可视化)

from dataset import INBreastDataset2012
from torchvision.transforms import v2 as T
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

train_dir = r'directory\of\training images and masks'
test_dir = r'directory\of\testing images and masks'

train_transform = T.Compose(
        [
            T.RandomRotation(degrees=35, expand=True, fill=255.0),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomVerticalFlip(p=0.5),

        ]
    )

train_data = INBreastDataset2012(train_dir,transform=train_transform)
test_data = INBreastDataset2012(test_dir)

train_dataloader = DataLoader(train_data, batch_size=1, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=1, shuffle=True)

plt.figure(figsize=(12,12))
for i, (imagen,mascara) in enumerate(train_dataloader):
    ax = plt.subplot(2,4,i+1)
    ax.title.set_text(f'imagen {i+1}')
    plt.imshow(imagen.squeeze(), cmap='gray')
    ax = plt.subplot(2,4,i+3)
    ax.title.set_text(f'mascara de imagen {i+1}')
    plt.imshow(mascara.squeeze(), cmap='gray')
    if i == 1:
        break

结果 Result transformation of images and masks

我还要补充一点,我已经尝试过使用 albumentations 和 torchvision.transforms v1。在 pytorch 和 youtube 视频的示例中,他们似乎做了和我一样的事情。

有人可以帮助我看看我做错了什么,或者有一个解决方案来确保转换相同,我将不胜感激。

如果需要任何额外信息,请询问。这是我的第一篇文章,所以我可能错过了一些东西。 先谢谢你了

python pytorch dataset pytorch-dataloader data-augmentation
1个回答
0
投票

您可以尝试沿通道维度连接图像和掩模,运行变换,然后将结果拆分回两个张量。

...

if self.transform is not None:
    #Concatenate along channel dimension
    image_and_mask = torch.cat([image, mask], dim=1)
 
    #Transform together
    transformed = self.transform(image_and_mask)
    
    #Slice the tensors out
    image = transformed[:, :image.shape[1], ...]
    mass_mask = transformed[:, image.shape[1]:, ...]

...
© www.soinside.com 2019 - 2024. All rights reserved.