NameError: name 'pil_mask' is not defined


This is my code. I have defined various operations like this:

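from PIL import Image, ImageOps
import numpy as np

# int_parameter and min_max_vals are helpers that are not shown in this question.

def identity(pil_img, pil_mask, _):
    # No-op: return the image and the mask unchanged
    return pil_img, pil_mask

def autocontrast(pil_img, pil_mask, _):
    # Photometric op: only the image changes, the mask is passed through
    return ImageOps.autocontrast(pil_img), pil_mask

def equalize(pil_img, pil_mask, _):
    # Photometric op: histogram equalization of the image only
    return ImageOps.equalize(pil_img), pil_mask

def rotate(pil_img, pil_mask, level):
    # Geometric op: rotate image and mask by the same angle so they stay aligned
    degrees = int_parameter(level, min_max_vals.rotate.max)
    if np.random.uniform() > 0.5:
        degrees = -degrees
    # For a label mask, Image.NEAREST would avoid interpolated label values
    return (pil_img.rotate(degrees, resample=Image.BILINEAR),
            pil_mask.rotate(degrees, resample=Image.BILINEAR))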

And several more along the same lines.
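Every operation above shares the signature (pil_img, pil_mask, level) and returns an (image, mask) tuple of PIL images, not a tensor. A quick check of that contract, using two of the operations on small synthetic images (the sizes and colors are arbitrary, chosen only for illustration):

```python
from PIL import Image, ImageOps

# Two of the operations from the question, repeated so this block runs on its own
def identity(pil_img, pil_mask, _):
    return pil_img, pil_mask

def autocontrast(pil_img, pil_mask, _):
    return ImageOps.autocontrast(pil_img), pil_mask

# Synthetic inputs (arbitrary size and color, for illustration only)
img = Image.new('RGB', (32, 32), color=(10, 20, 30))
mask = Image.new('L', (32, 32), color=0)

for op in (identity, autocontrast):
    out = op(img, mask, None)      # the level argument is unused by these two ops
    out_img, out_mask = out        # each op returns an (image, mask) tuple
    print(op.__name__, type(out).__name__, out_img.size, out_mask.mode)
```

Because each call returns a tuple of PIL images, an expression like fn(...) * mask_t[:, i] cannot work as written: the tuple has to be unpacked and the image converted to a tensor before it is weighted.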

Now I want to use PRIME augmentation (PRimitives of Maximum Entropy):

But I get this error:

    aug_x += fn(x_tensor, pil_mask, _) * mask_t[:, i] * weight
NameError: name 'pil_mask' is not defined
And this is the PRIME code:

augmentations = [
    (identity, 1.0)
    ]
class PRIMEAugModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.augmentations = augmentations
        self.num_transforms = len(augmentations)


    def forward(self, x, mask_t):
        x_tensor = torch.from_numpy(x)
        aug_x = torch.zeros_like(x_tensor)
        for i in range(self.num_transforms):
            fn, weight = self.augmentations[i]
            if fn.__name__ == 'identity':
                aug_x += fn(x_tensor, pil_mask, _) * mask_t[:, i] * weight
            else:
                aug_x += fn(x_tensor, pil_mask) * mask_t[:, i] * weight
        return aug_x

I'm confused about where and how I should define pil_mask.

deep-learning pytorch image-augmentation
1 answer (0 votes)
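The error occurs because forward uses pil_mask without it being defined anywhere forward can see it: it is not a parameter of forward, not a local variable, and not a module-level global. Depending on where the code runs, _ may be undefined as well (outside an interactive session nothing assigns it), so pass an explicit value such as None instead. In the code below, pil_mask is created at module level as a PIL Image (a placeholder mask of zeros) and then converted to a NumPy array, pil_mask_np. You can now use pil_mask_np wherever you need it in your code.
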
import numpy as np
import torch
import PIL.Image as Image

def identity(x_tensor, pil_mask, _):
    # Ignores the mask and the level and returns the input unchanged
    return x_tensor

augmentations = [
    (identity, 1.0)
]

# Example mask creation (grayscale image of zeros). It is defined at module
# level, so forward() can look it up as a global.
pil_mask = Image.new('L', (256, 256), color=0)
pil_mask_np = np.array(pil_mask)

class PRIMEAugModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.augmentations = augmentations
        self.num_transforms = len(augmentations)

    def forward(self, x, mask_t):
        # Accept a NumPy array (or a tensor) and work in float so the
        # weighted sum can be accumulated in place
        x_tensor = torch.as_tensor(x).float()
        aug_x = torch.zeros_like(x_tensor)
        for i in range(self.num_transforms):
            fn, weight = self.augmentations[i]
            # All ops take (image, mask, level); pass None for the unused
            # level instead of the undefined name _.
            # mask_t[:, i] must broadcast against x_tensor.
            aug_x += fn(x_tensor, pil_mask_np, None) * mask_t[:, i] * weight
        return aug_x
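A module-level pil_mask_np works in that answer only because identity ignores its mask argument. If each image has its own mask, another way to answer "where to define pil_mask" is to pass the masks into forward and build pil_mask inside it. The sketch below is an illustrative variant, not PRIME's reference code, and rests on several assumptions: RGB batches shaped (B, 3, H, W) with values in [0, 1], uint8 masks shaped (B, H, W), and mixing weights mask_t shaped (B, num_transforms). The helpers to_pil and to_tensor are hypothetical names. It unpacks the (image, mask) tuple that each operation returns and, to stay short, keeps only the mixed image, which is meaningful for photometric operations such as identity and autocontrast.

```python
import numpy as np
import torch
from PIL import Image, ImageOps

# Two photometric ops with the question's (image, mask, level) signature
def identity(pil_img, pil_mask, _):
    return pil_img, pil_mask

def autocontrast(pil_img, pil_mask, _):
    return ImageOps.autocontrast(pil_img), pil_mask

def to_pil(t):
    # Hypothetical helper: (3, H, W) float tensor in [0, 1] -> RGB PIL image
    arr = (t.clamp(0, 1) * 255).round().to(torch.uint8).permute(1, 2, 0).contiguous().numpy()
    return Image.fromarray(arr)

def to_tensor(img):
    # Hypothetical helper: RGB PIL image -> (3, H, W) float32 tensor in [0, 1]
    arr = np.asarray(img, dtype=np.float32) / 255.0
    return torch.from_numpy(arr).permute(2, 0, 1)

class PRIMEAugModule(torch.nn.Module):
    def __init__(self, augmentations):
        super().__init__()
        self.augmentations = augmentations
        self.num_transforms = len(augmentations)

    def forward(self, x, masks, mask_t, level=None):
        # x: (B, 3, H, W) floats in [0, 1]
        # masks: (B, H, W) uint8, one mask per image, passed in by the caller
        # mask_t: (B, num_transforms) mixing weights
        aug_x = torch.zeros_like(x)
        for b in range(x.shape[0]):
            pil_img = to_pil(x[b])
            # pil_mask is defined here, from the forward() argument
            pil_mask = Image.fromarray(masks[b].numpy())
            for i, (fn, weight) in enumerate(self.augmentations):
                # Unpack the (image, mask) tuple; only the image is mixed
                out_img, _out_mask = fn(pil_img, pil_mask, level)
                aug_x[b] += to_tensor(out_img) * mask_t[b, i] * weight
        return aug_x
```

Usage: aug = PRIMEAugModule([(identity, 1.0), (autocontrast, 1.0)]), then out = aug(x, masks, mask_t). Geometric operations such as rotate would need separate handling of the returned masks, since a weighted sum of masks rotated by different angles is not itself a valid label mask.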