自定义数据集上的 keras-unet-collection 训练导致不兼容的形状错误

问题描述 投票:0回答:0

使用的完整代码:

import os
import numpy as np
from keras_unet_collection import models, losses
from tensorflow import keras
from PIL import Image


def hybrid_loss(y_true, y_pred):
    loss_focal = losses.focal_tversky(y_true, y_pred, alpha=0.5, gamma=4 / 3)
    loss_iou = losses.iou_seg(y_true, y_pred)
    return loss_focal + loss_iou


model = models.unet_3plus_2d((128, 128, 3), n_labels=2, filter_num_down=[64, 128, 256, 512],
                             filter_num_skip=[64, 64, 64], filter_num_aggregate=256,
                             stack_num_down=2, stack_num_up=1, activation='ReLU', output_activation='Sigmoid',
                             batch_norm=True, pool='max', unpool=False, deep_supervision=True, name='unet3plus')

model.compile(loss=[hybrid_loss, hybrid_loss, hybrid_loss, hybrid_loss, hybrid_loss],
                  loss_weights=[0.25, 0.25, 0.25, 0.25, 1.0],
                  optimizer=keras.optimizers.Adam(learning_rate=1e-4))


def load_images(path, target_size=(128, 128), grayscale=False):
    images = []
    for filename in sorted(os.listdir(path)):
        image = Image.open(os.path.join(path, filename))
        if grayscale:
            image = image.convert('L')
        else:
            image = image.convert('RGB')
        image = image.resize(target_size, resample=Image.BICUBIC)
        images.append(np.array(image))
    return np.array(images)


train_images_path = 'data/train/images'
train_masks_path = 'data/train/mask'
val_images_path = 'data/val/images'
val_masks_path = 'data/val/mask'

image_size = (128, 128)

# Load the images and masks as numpy arrays
train_images = load_images(train_images_path, target_size=image_size)
train_masks = load_images(train_masks_path, target_size=image_size, grayscale=True)
val_images = load_images(val_images_path, target_size=image_size)
val_masks = load_images(val_masks_path, target_size=image_size, grayscale=True)

history = model.fit(train_images, train_masks, batch_size=8, epochs=50, validation_data=(val_images, val_masks))

错误:

Node: 'gradient_tape/hybrid_loss_2/mul/BroadcastGradientArgs'
Incompatible shapes: [131072] vs. [393216]
     [[{{node gradient_tape/hybrid_loss_2/mul/BroadcastGradientArgs}}]] [Op:__inference_train_function_14372]

我必须尽我最大的能力遵循 keras-unet-collection 模块文档中给出的示例,但我正在努力让它真正起作用。我正在使用我自己的图像加载功能,因为我的蒙版图像在加载之前不是灰度的。我的直接研究表明,在 hybrid_loss 函数中 y_true 和 y_pred 的形状不匹配,但是当我从 hybrid_loss 函数打印一些东西时,它在失败之前成功输出了多次。

显然,我对 tensorflow 等没有太多经验,所以如果你能帮我弄清楚是什么导致了这个问题,以及我如何解决它,那将非常有帮助

python tensorflow keras image-segmentation unet-neural-network
© www.soinside.com 2019 - 2024. All rights reserved.