我是计算机视觉和深度学习的新手。我正在尝试训练这个以 ResNet50 作为编码器的 UNet 模型:https://github.com/kevinlu1211/pytorch-unet-resnet-50-encoder 。我想让它接收两张 RGB 图像:两张图像先分别经过 ResNet50 处理,然后把得到的特征拼接(concatenate)起来,再传给解码器。我按这个思路做了修改,并把代码中的 n_classes 改为 3,以便像输入一样输出 3 通道的 RGB 图像,但得到的却是一张失真的图像,我不明白为什么。请帮我解决这个问题。
下面是我修改后的代码片段,用于让两个 RGB 输入都经过 ResNet50 处理:
# Fragment of a forward pass: the enclosing `def` and the code that prepares
# `x` and `pre_pools` are not part of this excerpt.
# First input: run `x` through every encoder block, recording the outputs.
for i, block in enumerate(self.down_blocks, 2): # encoder blocks, numbered from 2
x = block(x)
if i == (UNetWithResnet50Encoder.DEPTH - 1):
continue
pre_pools[f"layer_{i}"] = x # remember this block's output under its layer index
# Second input: `y` goes through the same input block, pooling and encoder
# blocks, with its intermediate outputs recorded in a separate dict.
pre_pools_inp2 = dict()
pre_pools_inp2[f"layer_0"] = y
y = self.input_block(y) # same input block as used elsewhere in the model (shared module)
pre_pools_inp2[f"layer_1"] = y
y = self.input_pool(y)
for i, block in enumerate(self.down_blocks, 2): # encoder blocks, numbered from 2
y = block(y)
if i == (UNetWithResnet50Encoder.DEPTH - 1):
continue
pre_pools_inp2[f"layer_{i}"] = y # remember this block's output under its layer index
# Join the two encoder outputs along dim 1.
# NOTE(review): if `x` and `y` have the same shape, dim 1 of the result is
# twice as large as for a single input — confirm `self.bridge` was built for it.
x = torch.cat([x,y],1)
x = self.bridge(x) # bridge between the down-sampling and up-sampling paths
for i, block in enumerate(self.up_blocks, 1):
key = f"layer_{UNetWithResnet50Encoder.DEPTH - 1 - i}" # skip-connection key for this decoder step
# NOTE(review): only `pre_pools` (recorded from `x`) is read here;
# `pre_pools_inp2` is filled above but never used in this excerpt.
x = block(x, pre_pools[key])
output_feature_map = x
x = self.out(x)
del pre_pools
# Optionally return the features that were fed into `self.out` as well.
if with_output_feature_map:
return x, output_feature_map
else:
return x
由于您没有贴出完整的模型结构,我无法确定问题出在哪里。不过,下面是我在扩散模型中使用的 UNet 架构(用 TensorFlow/Keras 实现),它适用于任意尺寸的图像(矩形或正方形均可)。
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import math
import warnings
import os
import sys
import getpass
from collections.abc import Iterable
from typing import List, Union, Callable
# Process-wide runtime configuration, applied as soon as this module is loaded.
# NOTE(review): these variables are written after `import tensorflow` has
# already run above — confirm TensorFlow still honours them when set this late.
os.environ.update({
    'TF_CPP_MIN_LOG_LEVEL': '2',
    'MLIR_CRASH_REPRODUCER_DIRECTORY': '1',
    'TF_ENABLE_ONEDNN_OPTS': '1',
})
warnings.filterwarnings('ignore')  # suppress every Python warning
tf.get_logger().setLevel('ERROR')  # TensorFlow's Python logger: errors only
# Runs the shell command `clear` (presumably to wipe the terminal; the command
# name is platform-specific — verify on the target OS).
os.system('clear')
# Ask TensorFlow to run `tf.function`-decorated code eagerly.
tf.config.run_functions_eagerly(True)
class ConvBlock(tf.keras.Model):
    """Stack of 3x3 convolutions, each followed by batch normalisation.

    One ``Conv2D`` (stride 1, swish activation, Glorot-normal initialiser) and
    one ``BatchNormalization`` layer are created per entry of ``filters`` and
    applied in that order.

    Args:
        filters: Output channel count of each convolution, in order.
            ``None`` (the default) means ``[64, 64]``.
        padding: Padding mode passed to every ``Conv2D``.
        **kwargs: Forwarded to ``tf.keras.Model`` (for example ``name``).
    """

    def __init__(self, filters=None, padding='valid', **kwargs):
        super().__init__(**kwargs)
        # A fresh list per instance: no list object is shared between
        # instances (mutable-default pitfall) or with the caller.
        self.filters = [64, 64] if filters is None else list(filters)
        self.padding = padding
        self.build()

    def build(self, input_dim=None):
        """Create the sub-layers. ``input_dim`` is accepted but not used."""
        self.convs = [
            layers.Conv2D(filters=f, kernel_size=3, strides=1,
                          padding=self.padding,
                          activation=tf.keras.activations.swish,
                          kernel_initializer=tf.keras.initializers.GlorotNormal(),
                          name=f"{self.name}_conv_{i+1}")
            for i, f in enumerate(self.filters)
        ]
        self.batch_norms = [
            layers.BatchNormalization(name=f"{self.name}_batch_norm_{i+1}")
            for i, _ in enumerate(self.filters)
        ]

    def call(self, inp, training=None):
        """Apply every conv/batch-norm pair to ``inp`` and return the result.

        Args:
            inp: Input tensor for the first convolution.
            training: Forwarded to each ``BatchNormalization`` layer; ``None``
                leaves the decision to Keras.
        """
        y = tf.identity(inp)
        for conv, bn in zip(self.convs, self.batch_norms):
            y = conv(y)
            # Forward the flag explicitly so batch norm follows the caller's
            # training/inference choice.
            y = bn(y, training=training)
        return y
class SinusoidalEmbedding(layers.Layer):
    """Embed a value as sines and cosines at geometrically spaced frequencies.

    Args:
        embedding_min_frequency: Lowest frequency used.
        embedding_max_frequency: Highest frequency used.
        embedding_dims: Total number of frequencies times two (half of the
            output comes from ``sin``, half from ``cos``); must be even.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.

    Raises:
        AssertionError: If ``embedding_dims`` is odd.
    """

    def __init__(self, embedding_min_frequency=1.0,
                 embedding_max_frequency=1000.0,
                 embedding_dims=32,
                 **kwargs):
        super().__init__(**kwargs)
        assert embedding_dims % 2 == 0, f"[SinusoidalEmbedding] `embedding_dims` must be even. Found {embedding_dims}"
        self.embedding_min_frequency = embedding_min_frequency
        self.embedding_max_frequency = embedding_max_frequency
        self.embedding_dims = embedding_dims

    def call(self, x, training=None):
        """Return ``concat([sin(w * x), cos(w * x)], axis=-1)``.

        ``w`` is ``2 * pi`` times ``embedding_dims // 2`` frequencies spaced
        evenly in log-space between the configured minimum and maximum.

        Args:
            x: Tensor to embed. Assumes its last axis broadcasts against the
                ``embedding_dims // 2`` frequencies (e.g. size 1) — TODO confirm
                against callers.
            training: Accepted for Keras compatibility; not used.
        """
        # Frequencies come from the configured attributes only; the original
        # carried an unused local `embedding_min_frequency = 1.0` here.
        frequencies = tf.math.exp(
            tf.linspace(tf.math.log(self.embedding_min_frequency),
                        tf.math.log(self.embedding_max_frequency),
                        self.embedding_dims // 2)
        )
        angular_speeds = tf.cast(2.0 * math.pi * frequencies, tf.float32)
        embeddings = tf.concat(
            [tf.math.sin(angular_speeds * x), tf.math.cos(angular_speeds * x)],
            axis=-1,
        )
        return embeddings
class DenoisingUNet(tf.keras.Model):
    """U-Net style encoder/decoder with skip connections.

    Down path: at each of ``depth - 1`` levels a ``ConvBlock`` is applied, its
    output is kept as a skip connection, and the result is max-pooled; a final
    ``ConvBlock`` follows. Up path: the tensor is upsampled, resized to the
    spatial size of the matching skip tensor, concatenated with it and passed
    through a ``ConvBlock``. A transposed-convolution head then produces
    ``num_channels`` output channels.

    Args:
        input_dim: ``(img_height, img_width)`` of the input images.
        num_channels: Channel count of the input image and of the output.
        depth: Number of ``ConvBlock`` levels on the down path.
        filters: One filter count per level (``len(filters) == depth``).
            ``None`` (the default) means ``[64, 128, 256, 512, 1024]``.
        last_layer_activation: Activation of the final output layer.
        add_sinusoidal_embedding: If true, ``embed_inp`` is passed through a
            ``SinusoidalEmbedding`` layer before being concatenated to the
            image; the three ``embedding_*`` arguments are then required.
        embedding_min_frequency: See ``SinusoidalEmbedding``.
        embedding_max_frequency: See ``SinusoidalEmbedding``.
        embedding_dims: See ``SinusoidalEmbedding``.
        **kwargs: Forwarded to ``tf.keras.Model``.

    Raises:
        AssertionError: On an ``input_dim`` that is not of length 2, a
            ``filters``/``depth`` mismatch, or missing embedding settings.
    """

    def __init__(self, input_dim: Iterable = (28, 28),
                 num_channels: int = 1,
                 depth: int = 5,
                 filters: Union[List[int], None] = None,
                 last_layer_activation: Callable = tf.identity,
                 add_sinusoidal_embedding=True,
                 embedding_min_frequency=None,
                 embedding_max_frequency=None,
                 embedding_dims=None,
                 **kwargs):
        super().__init__(**kwargs)
        if filters is None:
            # Built per call so no default list object is shared.
            filters = [64, 128, 256, 512, 1024]
        assert len(input_dim) == 2, f"input_dim must be of the form (img_height, img_width). Found input_dim = {input_dim}"
        assert len(filters) == depth, f"Number of filters must be same as depth. Found filters = {filters}, depth = {depth}"
        self.input_dim = input_dim
        self.num_channels = num_channels
        self.depth = depth
        self.filters = list(filters)
        # Previously accepted but ignored; now applied to the output layer.
        self.last_layer_activation = last_layer_activation
        self.add_sinusoidal_embedding = add_sinusoidal_embedding
        if add_sinusoidal_embedding:
            assert embedding_min_frequency is not None, "please provide embedding_min_frequency"
            assert embedding_max_frequency is not None, "please provide embedding_max_frequency"
            assert embedding_dims is not None, "please provide embedding_dims"
        self.embedding_min_frequency = embedding_min_frequency
        self.embedding_max_frequency = embedding_max_frequency
        self.embedding_dims = embedding_dims
        self.kernel_initializer = tf.keras.initializers.GlorotNormal()
        self.built = False

    def build(self, input_dim=None):
        """Create all sub-layers. ``input_dim`` is accepted but not used."""
        # Every decoder block uses 'valid' padding except the last ('same').
        up_path_paddings = ['valid'] * (self.depth - 2) + ['same']
        if self.add_sinusoidal_embedding:
            self.sin_embed = SinusoidalEmbedding(
                embedding_min_frequency=self.embedding_min_frequency,
                embedding_max_frequency=self.embedding_max_frequency,
                embedding_dims=self.embedding_dims)
        # Created for a context-conditioning path that the forward pass does
        # not currently use (see `__call__`).
        self.context_reshaper = tf.keras.Sequential([
            layers.Dense(self.input_dim[0], activation=tf.nn.tanh),
            layers.Lambda(lambda x: tf.expand_dims(x, -1)),
            layers.Dense(self.input_dim[1], activation=tf.nn.tanh),
            layers.Lambda(lambda x: tf.expand_dims(x, -1))
        ], name='context_reshaper')
        self.context_reshaper.trainable = False
        self.down_path_convs = [
            ConvBlock(filters=[f, f], name=f'conv_block_{i+1}')
            for i, f in enumerate(self.filters)
        ]
        self.poolers = [
            layers.MaxPooling2D(pool_size=2, name=f"max_pool_{i+1}")
            for i in range(self.depth - 1)
        ]
        self.upsamplers = [
            layers.UpSampling2D(size=2, name=f"upsampling_{i+1}")
            for i in range(self.depth - 1)
        ]
        rev_filters = list(reversed(self.filters))
        self.up_path_convs = [
            ConvBlock(filters=[rev_filters[i], rev_filters[i+1]],
                      padding=padding,
                      name=f"up_path_conv_block_{i+1}")
            for i, padding in enumerate(up_path_paddings)
        ]
        self.concats = [
            layers.Concatenate(name=f"concat_{i+1}")
            for i in range(self.depth - 1)
        ]
        self.de_conv = tf.keras.Sequential(
            [
                layers.Conv2DTranspose(filters=self.filters[0],
                                       kernel_size=3,
                                       strides=1,
                                       padding='valid',
                                       activation=tf.keras.activations.swish,
                                       kernel_initializer=self.kernel_initializer),
                layers.BatchNormalization(),
                layers.Conv2DTranspose(filters=self.filters[0],
                                       kernel_size=3,
                                       strides=1,
                                       padding='valid',
                                       activation=tf.keras.activations.swish,
                                       kernel_initializer=self.kernel_initializer),
                layers.BatchNormalization(),
                layers.Conv2DTranspose(filters=self.num_channels,
                                       kernel_size=1,
                                       strides=1,
                                       padding='valid',
                                       activation=self.last_layer_activation,
                                       kernel_initializer=self.kernel_initializer),
            ], name='DeConvBlock'
        )
        self.built = True

    @tf.function
    def __call__(self, inp, embed_inp=None, context=None, training=None):
        """Run the network on ``inp``.

        Args:
            inp: Image batch; cast to ``float32``.
            embed_inp: Conditioning input. Required when
                ``add_sinusoidal_embedding`` is true, in which case it is
                embedded and upsampled by ``input_dim`` before being
                concatenated to ``inp`` on the last axis. Otherwise it is
                concatenated as given, or skipped when ``None``.
            context: Accepted for API compatibility; currently ignored.
            training: Forwarded to the embedding layer, every ``ConvBlock``
                and the output head.

        Returns:
            The output of the transposed-convolution head.

        Raises:
            AssertionError: If an embedding is required but ``embed_inp`` is
                ``None``.
        """
        if not self.built:
            self.build()
        inp = tf.cast(inp, tf.float32)
        if self.add_sinusoidal_embedding:
            # Check before casting so a missing input gives this message
            # rather than a failure inside `tf.cast`.
            assert embed_inp is not None, "Provide an `embed_inp` to be embedded sinusoidally"
            embed_inp = tf.cast(embed_inp, tf.float32)
            embed_inp = self.sin_embed(embed_inp, training=training)
            embed_inp = layers.UpSampling2D(size=self.input_dim)(embed_inp)
        elif embed_inp is not None:
            embed_inp = tf.cast(embed_inp, tf.float32)
        # Context conditioning is disabled: `context` never reaches the network.
        if embed_inp is None:
            x = inp
        else:
            x = tf.concat([inp, embed_inp], axis=-1)
        skips = []
        for pool, conv in zip(self.poolers, self.down_path_convs):
            x = conv(x, training=training)
            skips.append(x)
            x = pool(x)
        # `zip` stops at the shorter `poolers`, so the last block runs here.
        x = self.down_path_convs[-1](x, training=training)
        for upsample, concat, conv in zip(self.upsamplers, self.concats, self.up_path_convs):
            x = upsample(x)
            y = skips.pop()
            # Match the skip tensor's height/width before concatenating.
            x = tf.image.resize(images=x, size=y.shape[1:3])
            x = concat([x, y])
            x = conv(x, training=training)
        op = self.de_conv(x, training=training)
        return op

    @tf.function
    def call(self, inp, embed_inp=None, context=None, training=None):
        """Delegate to ``__call__`` with the same arguments."""
        return self.__call__(inp=inp, embed_inp=embed_inp, context=context, training=training)

    @tf.function
    def create_weights(self, add_context=True):
        """Run one forward pass on random inputs so the layers get built.

        Args:
            add_context: If true, a random ``(1, 10)`` context tensor is passed
                along (the forward pass currently ignores it).

        Returns:
            ``True``.
        """
        inp = tf.random.normal((1, *self.input_dim, self.num_channels))
        embed_inp = tf.random.normal((1, 1, 1, 1)) if self.add_sinusoidal_embedding else None
        context = tf.random.normal((1, 10)) if add_context else None
        self(inp=inp, embed_inp=embed_inp, context=context)
        return True

    def _as_functional_model(self, inp_shape):
        """Wrap ``call`` in a functional Keras model built from ``inp_shape``."""
        assert len(inp_shape) <= 3, "At most 3 inputs are taken into DenoisingUNet"
        # The first shape is spatial only; the channel count is appended here.
        inputs = [layers.Input((*inp_shape[0], self.num_channels))]
        inputs.extend(layers.Input(shape) for shape in inp_shape[1:3])
        op = self.call(*inputs)
        return tf.keras.Model(inputs=inputs, outputs=op, name='2D-DenoisingUNet')

    def summary(self, inp_shape: List[Iterable]):
        """Print the layer summary of the network for the given input shapes.

        Args:
            inp_shape: Up to three shapes, in the order image ``(H, W)``,
                ``embed_inp`` shape, ``context`` shape (batch axis excluded).
        """
        self._as_functional_model(inp_shape).summary()

    def plot(self, inp_shape):
        """Write a diagram of the network to ``2DDenoisingUNet.png``.

        Args:
            inp_shape: Same format as for ``summary``.
        """
        model = self._as_functional_model(inp_shape)
        tf.keras.utils.plot_model(model, show_shapes=True, to_file='2DDenoisingUNet.png')
if __name__ == '__main__':
    # Demo: a depth-3 network for 64x64 RGB images; draw it, then print it.
    # Shapes are given in the order image (H, W), embed_inp, context.
    demo_shapes = [(64, 64), (1, 1, 1), (64, )]
    unet = DenoisingUNet(
        input_dim=(64, 64),
        num_channels=3,
        depth=3,
        filters=[64, 128, 256],
        add_sinusoidal_embedding=True,
        embedding_min_frequency=1.0,
        embedding_max_frequency=1000.0,
        embedding_dims=32,
    )
    unet.plot(inp_shape=demo_shapes)
    unet.summary(inp_shape=demo_shapes)
上面的代码以尺寸为 (64, 64, 3) 的图像为例。我还在尺寸为 (224, 224, 3)(ImageNet)和 (178, 218, 3)(CelebA)的图像上试过,效果都很好。