使用的完整代码:
import os
import numpy as np
from keras_unet_collection import models, losses
from tensorflow import keras
from PIL import Image
def hybrid_loss(y_true, y_pred):
loss_focal = losses.focal_tversky(y_true, y_pred, alpha=0.5, gamma=4 / 3)
loss_iou = losses.iou_seg(y_true, y_pred)
return loss_focal + loss_iou
model = models.unet_3plus_2d((128, 128, 3), n_labels=2, filter_num_down=[64, 128, 256, 512],
filter_num_skip=[64, 64, 64], filter_num_aggregate=256,
stack_num_down=2, stack_num_up=1, activation='ReLU', output_activation='Sigmoid',
batch_norm=True, pool='max', unpool=False, deep_supervision=True, name='unet3plus')
model.compile(loss=[hybrid_loss, hybrid_loss, hybrid_loss, hybrid_loss, hybrid_loss],
loss_weights=[0.25, 0.25, 0.25, 0.25, 1.0],
optimizer=keras.optimizers.Adam(learning_rate=1e-4))
def load_images(path, target_size=(128, 128), grayscale=False):
images = []
for filename in sorted(os.listdir(path)):
image = Image.open(os.path.join(path, filename))
if grayscale:
image = image.convert('L')
else:
image = image.convert('RGB')
image = image.resize(target_size, resample=Image.BICUBIC)
images.append(np.array(image))
return np.array(images)
train_images_path = 'data/train/images'
train_masks_path = 'data/train/mask'
val_images_path = 'data/val/images'
val_masks_path = 'data/val/mask'
image_size = (128, 128)
# Load the images and masks as numpy arrays
train_images = load_images(train_images_path, target_size=image_size)
train_masks = load_images(train_masks_path, target_size=image_size, grayscale=True)
val_images = load_images(val_images_path, target_size=image_size)
val_masks = load_images(val_masks_path, target_size=image_size, grayscale=True)
history = model.fit(train_images, train_masks, batch_size=8, epochs=50, validation_data=(val_images, val_masks))
错误:
Node: 'gradient_tape/hybrid_loss_2/mul/BroadcastGradientArgs'
Incompatible shapes: [131072] vs. [393216]
[[{{node gradient_tape/hybrid_loss_2/mul/BroadcastGradientArgs}}]] [Op:__inference_train_function_14372]
我必须尽我最大的能力遵循 keras-unet-collection 模块文档中给出的示例,但我正在努力让它真正起作用。我正在使用我自己的图像加载功能,因为我的蒙版图像在加载之前不是灰度的。我的直接研究表明,在 hybrid_loss 函数中 y_true 和 y_pred 的形状不匹配,但是当我从 hybrid_loss 函数打印一些东西时,它在失败之前成功输出了多次。
显然,我对 tensorflow 等没有太多经验,所以如果你能帮我弄清楚是什么导致了这个问题,以及我如何解决它,那将非常有帮助