Image denoising using a UNet architecture

Problem description

I am new to computer vision and deep learning. I am trying to train this UNet model, which uses ResNet50 as the encoder: https://github.com/kevinlu1211/pytorch-unet-resnet-50-encoder. I want to modify it so that two RGB images are passed in, each first processed by the ResNet50 encoder, and the resulting feature layers are concatenated before being passed to the decoder. I tried to do this and changed n_classes in the code to 3 so that the output is a 3-channel RGB image like the inputs, but it gives me a distorted image like this, and I don't understand why. Please help me fix this.

The part of the code that I modified to process two RGB inputs through ResNet50 is here -

    for i, block in enumerate(self.down_blocks, 2):  # for all the down blocks 
        x = block(x)
        if i == (UNetWithResnet50Encoder.DEPTH - 1):
            continue
        pre_pools[f"layer_{i}"] = x  ## creating all the down sampling layers

    pre_pools_inp2 = dict()
    pre_pools_inp2[f"layer_0"] = y
    y = self.input_block(y)  #
    pre_pools_inp2[f"layer_1"] = y 
    y = self.input_pool(y)   

    for i, block in enumerate(self.down_blocks, 2):  # for all the down blocks 
        y = block(y)
        if i == (UNetWithResnet50Encoder.DEPTH - 1):
            continue
        pre_pools_inp2[f"layer_{i}"] = y  ## creating all the down sampling layers
    x = torch.cat([x,y],1)
    x = self.bridge(x)  # this is now the bridge between down sampling and up sampling 

    for i, block in enumerate(self.up_blocks, 1):
        key = f"layer_{UNetWithResnet50Encoder.DEPTH - 1 - i}"  # now using that bridge for upsampling
        x = block(x, pre_pools[key])
    output_feature_map = x
    x = self.out(x)
    del pre_pools
    if with_output_feature_map:
        return x, output_feature_map
    else:
        return x
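
For reference, ResNet50's deepest feature map has 2048 channels, so concatenating the two encoder outputs channel-wise doubles the channel count that reaches the bridge. Below is a minimal sketch of that bookkeeping; the shapes assume 224x224 inputs, and the Bridge(in_channels, out_channels) call in the comment is an assumption about the linked repo's API, not a confirmed signature.

    import torch

    # two ResNet50 bottleneck feature maps, e.g. for 224x224 RGB inputs: (N, 2048, 7, 7) each
    x = torch.randn(1, 2048, 7, 7)
    y = torch.randn(1, 2048, 7, 7)

    fused = torch.cat([x, y], 1)   # channel-wise concatenation -> (N, 4096, 7, 7)
    print(fused.shape)             # torch.Size([1, 4096, 7, 7])

    # the bridge therefore has to accept 4096 input channels instead of 2048, e.g.
    # (hypothetical constructor): bridge = Bridge(4096, 2048)
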
Tags: deep-learning, computer-vision, conv-neural-network, image-segmentation, unet-neural-network
1 Answer

I'm not sure about your architecture, since you didn't post it. However, here is a UNet architecture I use for diffusion models; it works with any image size (rectangular or square).

import tensorflow as tf
from tensorflow.keras import layers
import math
import warnings
import os
from collections.abc import Iterable
from typing import List, Callable

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['MLIR_CRASH_REPRODUCER_DIRECTORY'] = '1'
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1'
tf.get_logger().setLevel('ERROR')
warnings.filterwarnings('ignore')
os.system('clear')
tf.config.run_functions_eagerly(True)

class ConvBlock(tf.keras.Model):
    def __init__(self, filters = [64, 64], padding = 'valid', **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.padding = padding
        self.build()

    def build(self, input_dim = None):
        self.convs = [
                        layers.Conv2D(filters = f, kernel_size = 3, strides = 1, 
                                      padding = self.padding, activation = tf.keras.activations.swish,
                                      kernel_initializer = tf.keras.initializers.GlorotNormal(), 
                                      name = f"{self.name}_conv_{i+1}")
                        for i, f in enumerate(self.filters)
                     ]
        self.batch_norms = [
                                layers.BatchNormalization(name = f"{self.name}_batch_norm_{i+1}") 
                                for i, _ in enumerate(self.filters)
                           ]

    def call(self, inp, training = None):
        y = tf.identity(inp)
        for conv, bn in zip(self.convs, self.batch_norms):
            y = conv(y)
            y = bn(y)

        return y
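
# Note: with the default padding = 'valid' and 3x3 kernels, each convolution trims 2 pixels
# from the height and width, so one ConvBlock (two convolutions) shrinks a feature map by 4
# pixels in each spatial dimension; the up path in DenoisingUNet compensates by resizing to
# the skip-connection shape before concatenating.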

class SinusoidalEmbedding(layers.Layer):
    def __init__(self, embedding_min_frequency = 1.0, 
                       embedding_max_frequency = 1000.0, 
                       embedding_dims = 32, 
                       **kwargs):
        
        super().__init__(**kwargs)
        assert embedding_dims % 2 == 0, f"[SinusoidalEmbedding] `embedding_dims` must be even. Found {embedding_dims}"
        self.embedding_min_frequency = embedding_min_frequency
        self.embedding_max_frequency = embedding_max_frequency
        self.embedding_dims = embedding_dims

    def call(self, x, training = None):
        frequencies = tf.math.exp(
                                    tf.linspace(tf.math.log(self.embedding_min_frequency), 
                                                tf.math.log(self.embedding_max_frequency), self.embedding_dims // 2)
                                )
        angular_speeds = tf.cast(2.0 * math.pi * frequencies, tf.float32)
        embeddings = tf.concat([tf.math.sin(angular_speeds * x), tf.math.cos(angular_speeds * x)], axis = -1)
        return embeddings
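
# Note (illustrative shapes, not from the original answer): for an `embed_inp` of shape
# (batch, 1, 1, 1) and embedding_dims = 32, SinusoidalEmbedding returns a (batch, 1, 1, 32)
# tensor; DenoisingUNet below upsamples it to the full image resolution and concatenates it
# with the image along the channel axis.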


class DenoisingUNet(tf.keras.Model):
    def __init__(self, input_dim: Iterable = (28, 28),
                       num_channels: int = 1, 
                       depth: int = 5, 
                       filters: List[int] = [64, 128, 256, 512, 1024],
                       last_layer_activation: Callable = tf.identity,
                       add_sinusoidal_embedding = True,
                       embedding_min_frequency = None,
                       embedding_max_frequency = None,
                       embedding_dims = None,
                       **kwargs):
        super().__init__(**kwargs)
        assert len(input_dim) == 2, f"input_dim must be of the form (img_height, img_width). Found input_dim = {input_dim}"
        assert len(filters) == depth, f"Number of filters must be same as depth. Found filters = {filters}, depth = {depth}"
        self.input_dim = input_dim
        self.num_channels = num_channels
        self.depth = depth
        self.filters = filters
        self.add_sinusoidal_embedding = add_sinusoidal_embedding
        if add_sinusoidal_embedding:
            assert embedding_min_frequency is not None, "please provide embedding_min_frequency"
            assert embedding_max_frequency is not None, "please provide embedding_max_frequency"
            assert embedding_dims is not None, "please provide embedding_dims"
            self.embedding_min_frequency = embedding_min_frequency
            self.embedding_max_frequency = embedding_max_frequency
            self.embedding_dims = embedding_dims
        
        self.kernel_initializer = tf.keras.initializers.GlorotNormal()
        self.built = False

    def build(self, input_dim = None):

        up_path_paddings = ['valid' for _ in range(self.depth - 2)]
        up_path_paddings.append('same')

        if self.add_sinusoidal_embedding:
            self.sin_embed = SinusoidalEmbedding(embedding_min_frequency = self.embedding_min_frequency,
                                                 embedding_max_frequency = self.embedding_max_frequency,
                                                 embedding_dims = self.embedding_dims)

        self.context_reshaper = tf.keras.Sequential([
                                                    layers.Dense(self.input_dim[0], activation = tf.nn.tanh),
                                                    layers.Lambda(lambda x: tf.expand_dims(x, -1)),
                                                    layers.Dense(self.input_dim[1], activation = tf.nn.tanh),
                                                    layers.Lambda(lambda x: tf.expand_dims(x, -1))
                                                ], name = 'context_reshaper')

        self.context_reshaper.trainable = False

        self.down_path_convs = [
                                ConvBlock(filters = [f, f], name = f'conv_block_{i+1}') 
                                for i, f in enumerate(self.filters)
                               ]
        
        self.poolers = [
                        layers.MaxPooling2D(pool_size = 2, name = f"max_pool_{i+1}")
                        for i in range(self.depth - 1)
                       ]
        
        self.upsamplers = [
                            layers.UpSampling2D(size = 2, name = f"upsampling_{i+1}")
                            for i in range(self.depth - 1)
                          ]
        
        rev_filters = list(reversed(self.filters))
        self.up_path_convs = [
                                ConvBlock(filters = [rev_filters[i], rev_filters[i+1]],
                                          padding = padding,
                                          name = f"up_path_conv_block_{i+1}")
                                for i, padding in enumerate(up_path_paddings)
                             ]

        self.concats = [
                        layers.Concatenate(name = f"concat_{i+1}")
                        for i in range(self.depth - 1)
                        ]

        self.de_conv = tf.keras.Sequential(
                                        [
                                            layers.Conv2DTranspose(filters = self.filters[0],
                                                                   kernel_size = 3,
                                                                   strides = 1,
                                                                   padding = 'valid',
                                                                   activation = tf.keras.activations.swish,
                                                                   kernel_initializer = self.kernel_initializer),
                                            layers.BatchNormalization(),
                                            layers.Conv2DTranspose(filters = self.filters[0],
                                                                   kernel_size = 3,
                                                                   strides = 1,
                                                                   padding = 'valid',
                                                                   activation = tf.keras.activations.swish,
                                                                   kernel_initializer = self.kernel_initializer),
                                            layers.BatchNormalization(),
                                            layers.Conv2DTranspose(filters = self.num_channels,
                                                                   kernel_size = 1,
                                                                   strides = 1,
                                                                   padding = 'valid',
                                                                   activation = tf.identity,
                                                                   kernel_initializer = self.kernel_initializer),
                                        ], name = 'DeConvBlock'
                                        )
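
        # The two 3x3 'valid' Conv2DTranspose layers above grow the feature map by 4 pixels
        # in height and width, undoing the trim introduced by the first valid-padded ConvBlock,
        # and the final 1x1 layer maps back to num_channels, so the output resolution matches
        # the input resolution.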

        self.built = True

    @tf.function
    def __call__(self, inp, embed_inp = None, context = None, training = None):     
        if not self.built:
            self.build()

        inp = tf.cast(inp, tf.float32)
        if embed_inp is not None:
            embed_inp = tf.cast(embed_inp, tf.float32)
        if context is not None:
            context = tf.cast(context, tf.float32)

        if self.add_sinusoidal_embedding:
            assert embed_inp is not None, "Provide an `embed_inp` to be embedded sinusoidally"
            embed_inp = self.sin_embed(embed_inp, training = training)
            # repeat the (batch, 1, 1, embedding_dims) embedding up to the full image resolution
            embed_inp = layers.UpSampling2D(size = self.input_dim)(embed_inp)

        # stack the image, the (optional) embedding and the (optional) context map along channels
        features = [inp]
        if embed_inp is not None:
            features.append(embed_inp)
        if context is not None:
            context = self.context_reshaper(context)  # project the context vector to a (H, W, 1) map
            features.append(context)
        x = tf.concat(features, axis = -1) if len(features) > 1 else inp
        
        skips = []
        for pool, conv in zip(self.poolers, self.down_path_convs):
            x = conv(x)
            skips.append(x)
            x = pool(x)

        x = self.down_path_convs[-1](x)

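        # Up path: after each upsampling, the feature map is resized to exactly match the
        # corresponding skip connection before concatenation; this resize is what lets the
        # network handle arbitrary rectangular or square image sizes.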
        for upsample, concat, conv in zip(self.upsamplers, self.concats, self.up_path_convs):
            x = upsample(x)
            y = skips.pop()
            x =  tf.image.resize(images = x, size = y.shape[1:3])
            x = concat([x, y])
            x = conv(x)

        op = self.de_conv(x)

        return op

    @tf.function
    def call(self, inp, embed_inp = None, context = None, training = None):
        return self.__call__(inp = inp, embed_inp = embed_inp, context = context, training = training)

    
    @tf.function
    def create_weights(self, add_context = True):
        inp = tf.random.normal((1, *self.input_dim, self.num_channels))
        if self.add_sinusoidal_embedding:
            embed_inp = tf.random.normal((1, 1, 1, 1))
        else:
            embed_inp = None

        if add_context:
            context = tf.random.normal((1, 10))
        else:
            context = None
        op = self(inp = inp, embed_inp = embed_inp, context = context)
        return True


    def summary(self, inp_shape: List[Iterable]):
        assert len(inp_shape) <= 3, f"At most 3 inputs are taken into DenoisingUNet"
        if len(inp_shape) == 1:
            inp_layer = layers.Input((*inp_shape[0], self.num_channels))
            inputs = [inp_layer]
        elif len(inp_shape) == 2:
            inp_layer = layers.Input((*inp_shape[0], self.num_channels))
            embed_inp_layer = layers.Input(inp_shape[1])
            inputs = [inp_layer, embed_inp_layer]
        else:
            inp_layer = layers.Input((*inp_shape[0], self.num_channels))
            embed_inp_layer = layers.Input(inp_shape[1])
            context_inp_layer = layers.Input(inp_shape[2])
            inputs = [inp_layer, embed_inp_layer, context_inp_layer]

        op = self.call(*inputs)
        model = tf.keras.Model(inputs = inputs, outputs = op, name = '2D-DenoisingUNet')
        model.summary()

    def plot(self, inp_shape):
        assert len(inp_shape) <= 3, f"At most 3 inputs are taken into DenoisingUNet"
        if len(inp_shape) == 1:
            inp_layer = layers.Input((*inp_shape[0], self.num_channels))
            inputs = [inp_layer]
        elif len(inp_shape) == 2:
            inp_layer = layers.Input((*inp_shape[0], self.num_channels))
            embed_inp_layer = layers.Input(inp_shape[1])
            inputs = [inp_layer, embed_inp_layer]
        else:
            inp_layer = layers.Input((*inp_shape[0], self.num_channels))
            embed_inp_layer = layers.Input(inp_shape[1])
            context_inp_layer = layers.Input(inp_shape[2])
            inputs = [inp_layer, embed_inp_layer, context_inp_layer]

        op = self.call(*inputs)
        model = tf.keras.Model(inputs = inputs, outputs = op, name = '2D-DenoisingUNet')
        tf.keras.utils.plot_model(model, show_shapes = True, to_file = '2DDenoisingUNet.png')


if __name__ == '__main__':
    unet = DenoisingUNet(input_dim = (64, 64),
                         num_channels = 3,
                         depth = 3,
                         filters = [64, 128, 256],
                         add_sinusoidal_embedding = True,
                         embedding_min_frequency = 1.0,
                         embedding_max_frequency = 1000.0,
                         embedding_dims = 32)
    unet.plot(inp_shape = [(64, 64), (1, 1, 1), (64, )])
    unet.summary(inp_shape = [(64, 64), (1, 1, 1), (64, )])

The code above is an example for an image size of (64, 64, 3). I have also tried it with image sizes of (224, 224, 3) (ImageNet) and (178, 218, 3) (CelebA), and it works fine.
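
As an illustration, here is a minimal sketch of instantiating it for 224x224 RGB images; the depth, filters, and embedding settings are assumptions chosen for the example, not the answerer's exact configuration.

unet = DenoisingUNet(input_dim = (224, 224),
                     num_channels = 3,
                     depth = 5,
                     filters = [64, 128, 256, 512, 1024],
                     add_sinusoidal_embedding = True,
                     embedding_min_frequency = 1.0,
                     embedding_max_frequency = 1000.0,
                     embedding_dims = 32)
# image input, a (1, 1, 1) noise-level input for the sinusoidal embedding, and a context vector
unet.summary(inp_shape = [(224, 224), (1, 1, 1), (64, )])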
