级联卷积神经网络 - 使用 TensorFlow API 的多输入和多输出

问题描述 投票:0回答:1

我正在尝试实施 this 论文中提出的级联模型,但在数据加载管道中遇到了一些问题。该模型的总体架构如下所示:

数据子集是:road: (img, mask_road) and centerline: (same img, mask_centerline)。为了构建数据加载管道,我使用

tf.data.Dataset
API 创建了这些输入,即读取数据、解码和转换为张量 [0,1]。因此,为了训练模型,我尝试压缩我的输入,例如:

zip_train = tf.data.Dataset.zip((dataset_train,center_train))
zip_valid = tf.data.Dataset.zip((dataset_val,center_val))
zip_test = tf.data.Dataset.zip((dataset_test,center_test))

在上面的模型图中,对于中心线提取(网络 2),输入是来自网络 1(道路检测)的最后一个特征图和相应的子集。因此,我在模型定义期间将它们连接起来(请参阅下面的代码附加要点文件)。当我尝试运行代码时,错误表明我的连接层不兼容,特别是关于通道,如下所示:

history = model.fit(
      zip_train, 
      epochs=epochs, 
      steps_per_epoch=steps, 
      validation_data=zip_valid, 
      callbacks=callbacks
  )

Input 0 of layer "d1_11_conv" is incompatible with the layer:
expected axis -1 of input shape to have value 67, 
but received input with shape (None, 512, 512, 65)

Call arguments received by layer 'model_12' (type Functional):

  • inputs=('tf.Tensor(shape=(None, 512, 512, 3), dtype=float32)',
   'tf.Tensor(shape=(None, 512, 512, 1), dtype=float32)')
  • training=True
  • mask=None

如何解决?这是包含级联模型定义和数据加载管道的可重现代码

tensorflow keras deep-learning tensorflow-datasets multitasking
1个回答
0
投票

我们先总结一下全貌。您尝试构建和运行的模型是一种双自动编码器模型,旨在同时解决两个任务。因此,如果我们传递一个输入图像,模型将给出两个输出,即道路图和中心线图。但首先,我们需要用给定的数据集试验这个模型,其中存在图像和相应的道路和中心线分割掩模。简而言之,我们可以将这个问题定义为具有 1 个输入和 2 个输出的语义分割。

为此类任务使用

tf.data
API 构建训练数据加载器非常简单。但是,可以有不同的方法,但总体设置是相同的。关于您在连接层时遇到的错误,我认为这是预料之中的。但是根据论文中的数字,我认为您不需要那样做。您可以简单地将第一个网络的特征图传递给下一个网络。让我们一步步构建这个项目。我正在使用 TF 2.11,使用 P100 GPU 在 kaggle 上进行测试。

型号

一些常见的图层块。

def ConvBlock(filters, kernel, kernel_initializer, activation, name=None):
    
    if name is None:
        name = "ConvBlock" + str(backend.get_uid("ConvBlock"))
    
    def apply(input):
        c1 = layers.Conv2D(
            filters=filters,
            kernel_size=kernel,
            padding='same',
            kernel_initializer=kernel_initializer,
            name=name+'_conv'
        )(input)
        c1 = layers.BatchNormalization(name=name+'_batch')(c1)
        c1 = layers.Activation(activation,name=name+'_active')(c1)
        return c1
    
    return apply
def DownConvBlock(filters, kernel, kernel_initializer, activation, name=None):
    
    if name is None:
        name = "DownConvBlock" + str(backend.get_uid("DownConvBlock"))
    
    def apply(input):
        d1 = layers.Conv2DTranspose(
            filters=filters,
            kernel_size=kernel,
            padding='same',
            kernel_initializer=kernel_initializer,
            name=name+'_conv'
        )(input)
        d1 = layers.BatchNormalization(name=name+'_batch')(d1)
        d1 = layers.Activation(activation,name=name+'_active')(d1)
        return d1
    
    return apply

道路面具检测任务的子模型。

def network_mask(input, activation, kernel_initializer, kernel_size):
    # Network 1
    # ENCODER
    x = input
    for fmap in [64, 128, 256, 512]:
        x = ConvBlock(fmap, kernel_size, kernel_initializer, activation)(x)
        x = ConvBlock(fmap, kernel_size, kernel_initializer, activation)(x)
        x = layers.MaxPool2D(pool_size=(2,2), strides=None, padding='same')(x)

    # DECODER   
    for fmap in [512, 256, 128, 64]:
        x = layers.UpSampling2D(size=(2,2), interpolation='bilinear')(x)
        x = DownConvBlock(fmap, kernel_size, kernel_initializer, activation)(x)
        x = DownConvBlock(fmap, kernel_size, kernel_initializer, activation)(x)

    x = layers.Conv2D(
            filters=1, 
            kernel_size=(1,1),
            kernel_initializer=kernel_initializer,
            activation=None,
    )(x)
    
    return x

中心线面具检测任务的子模型。

def network_centerline(input, activation, kernel_initializer, kernel_size):
    # Network 2
    # ENCODER
    x = input
    for fmap in [64, 128, 256]:
        x = ConvBlock(fmap, kernel_size, kernel_initializer, activation)(x)
        x = ConvBlock(fmap, kernel_size, kernel_initializer, activation)(x)
        x = layers.MaxPool2D(pool_size=(2,2), strides=None, padding='same')(x)

    # DECODER   
    for fmap in [256, 128, 64]:
        x = layers.UpSampling2D(size=(2,2), interpolation='bilinear')(x)
        x = DownConvBlock(fmap, kernel_size, kernel_initializer, activation)(x)
        x = DownConvBlock(fmap, kernel_size, kernel_initializer, activation)(x)
        
    x = layers.Conv2DTranspose(
        filters=1, 
        kernel_size=(1,1), 
        kernel_initializer=kernel_initializer,
        activation=None, 
    )(x)
    
    return x

全级联网络,即CasNet。

def CasNet(activation, kernel_initializer, kernel_size):
    input = keras.Input(shape=(img_size, img_size, channel), name='images')
    
    mask_feat = network_mask(input, activation, kernel_initializer, kernel_size)
    centerline_feat = network_centerline(
        mask_feat, activation, kernel_initializer, kernel_size
    )
    
    mask_op = keras.layers.Activation(
        'sigmoid', name='mask', dtype=tf.float32
    )(mask_feat)
    centerline_op = keras.layers.Activation(
        'sigmoid', name='centerline', dtype=tf.float32
    )(centerline_feat)
    
    model = keras.Model(
        inputs={
            'images': input
        },
        outputs={
            'mask': mask_op,
            'centerline': centerline_op
        },
        name='CasNet'
    )
    return model

数据加载器

keras 中的增强管道。在未来的日子里,我们可以为此使用

keras-cv

set_seed = 101
rand_flip = layers.RandomFlip("horizontal_and_vertical", seed=set_seed)
rand_rote = layers.RandomRotation(factor=0.01, seed=set_seed)
# more: https://keras.io/api/layers/preprocessing_layers/image_augmentation/

def keras_augment(image, label, centerline):
    tensors =  tf.concat([image, label, centerline], axis=-1)
    
    def apply_augment(x):
        x = rand_flip(x)
        x = rand_rote(x)
        return x
    
    aug_tensors = apply_augment(tensors)
    image, label, centerline = tf.split(aug_tensors, [3, 1, 1], axis=-1)
    return image, label, centerline

加载样本(道路、面罩、中心线)。道路图像的像素值是正常的RGB颜色,范围从

0~255
。并且道路遮罩和道路中心线的像素值介于
0-255
之间,具有3个颜色通道。我们将标准化这些值。

def read_files(image_path, mask=False):
    image = tf.io.read_file(image_path)
    if mask:
        image = tf.io.decode_png(image, channels=1, dtype=tf.uint8)
        image = tf.image.resize(
            images=image, 
            size=[img_size, img_size], 
            method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
        )
        image = tf.where(image == 255, 1, 0)
        image = tf.cast(image, tf.float32)
    else:
        image = tf.io.decode_png(image, channels=3, dtype=tf.uint8)
        image = tf.image.resize(images=image, size=[img_size, img_size])
        image = tf.cast(image, tf.float32)
        image = image / 255.
    return image

def load_data(image_list, label_list, centerline_list): 
    image = read_files(image_list)
    label = read_files(label_list, mask=True)
    center = read_files(centerline_list, mask=True)
    return image, label, center

注意这里,我们如何打包(下面的

prepare_dict
方法)单输入和多输出的数据。对于多输入和多输出或多输入和单输出等,可以做同样的事情。同样,如前所述,可以有不同的方式使用相同的 API 加载此类数据集,但整体设置是相同的。为了避免混淆,我不想提及可能的替代方案。

def prepare_dict(image_batch, label_batch, centerline_batch):
    return {'images': image_batch}, {'mask':label_batch, 'centerline':centerline_batch}

def dataloader(image_list, label_list, center_list, split='train'):
    dataset = tf.data.Dataset.from_tensor_slices(
        (image_list, label_list, center_list)
    )
    dataset = dataset.shuffle(batch_size * 8) if split == 'train' else dataset
    dataset = dataset.repeat() if split == 'train' else dataset
    dataset = dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.map(keras_augment) if split == 'train' else dataset
    dataset = dataset.map(prepare_dict, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.prefetch(buffer_size=AUTOTUNE)
    return dataset

download

download

编译运行

让我们用损失和指标编译模型并拟合它。对于损失和指标,我们将使用 this 库,直到 keras-cv 准备好执行分割任务。请参阅下面的

loss
metrics
参数,我们正在为模型的两个输出传递损失和度量函数。虽然我们可以简单地传递一个损失/度量方法,并将用于两个输出,但很高兴知道我们可以以这种格式传递损失/度量方法。

model.compile(
    optimizer=keras.optimizers.Adam(
        learning_rate=0.0001
    ),
    loss={
        'mask':sm.losses.bce_jaccard_loss,
        'centerline': sm.losses.binary_focal_jaccard_loss
    },
    metrics={
        'mask': sm.metrics.iou_score,
        'centerline': sm.metrics.f1_score
    }
)

history = model.fit(
    train_ds, 
    validation_data=valid_ds,
    steps_per_epoch=len(train_images_path) // batch_size,
    callbacks=my_callbacks,
    epochs=epoch
)
...
...
160/160 [==============================] - 186s
loss: 1.0082 - centerline_loss: 0.7613 - mask_loss: 0.2469 -
centerline_f1-score: 0.4074 - mask_iou_score: 0.8115 - 
val_loss: 1.2867 - val_centerline_loss: 0.7986 - 
val_mask_loss: 0.4882 - val_centerline_f1-score: 0.3572 - 
val_mask_iou_score: 0.6860

160/160 [==============================] - 186s 1s/step - 
loss: 0.9827 - centerline_loss: 0.7491 - mask_loss: 0.2336 - 
centerline_f1-score: 0.4223 - mask_iou_score: 0.8210 - 
val_loss: 1.4251 - val_centerline_loss: 0.8222 - 
val_mask_loss: 0.6028 - val_centerline_f1-score: 0.3160 - 
val_mask_iou_score: 0.6344
...
...

download

download

download

完整代码和资源

这里是完整代码,它在 kaggle (P100, TF 2.11) 上运行。这里有一些可能会派上用场的资源。其中大部分与分割建模和损失方法选择有关。

© www.soinside.com 2019 - 2024. All rights reserved.