嗨,我需要一些关于 Keras 中自定义损失函数的帮助。我基本上是在构建一个带有第二个输入的 UNET,它采用原始 UNET 论文中的权重图。 然而,我正在使用这个 UNET 进行图像合成,我的损失函数是使用三个输入(输入图像、重建图像和权重图)计算的感知损失和像素损失的组合。 UNET 模型是一个带有编码器、解码器和跳过连接的标准 UNET。
下面是我的网络代码和损失函数:
def synthesis_unet_weights(pretrained_weights=None, input_shape=(SIZE_s, SIZE_s, 3), num_classes=1, is_training=True):
ip = Input(shape=input_shape)
weight_ip = Input(shape=input_shape[:2] + (num_classes,))
UNET encoder with the first Conv2D layer taking input ip
#---------------------------------------------------------------------------------------------------------------------------
center = Conv2D(1024, (3,3),padding='same', activation='relu', kernel_initializer=initializer)(pool4)
center = Conv2D(1024, (3,3),padding='same', activation='relu', kernel_initializer=initializer)(center)
#---------------------------------------------------------------------------------------------------------------------------
UNET decoder with the last layer up1
classify = Conv2D(num_classes, (1,1), activation='sigmoid')(up1)
if is_training:
model=Model(inputs=[ip, weight_ip], outputs=[classify])
model.add_loss(perceptual_loss_weight(ip,classify,weight_ip))
return model
else:
model = Model(inputs=[ip], outputs=[classify])
weight_ip=ip
model.add_loss(perceptual_loss_weight(ip,classify,weight_ip))
opt2 = tf.keras.optimizers.Adam(learning_rate=1e-3,clipnorm=1.0)
model.compile(optimizer=opt2)
return model
return model
def perceptual_loss_weight(input_image , reconstruct_image, weights):
input_image = clip_0_1(input_image)
reconstruct_image = tf.concat((reconstruct_image,reconstruct_image,reconstruct_image),axis=-1)
reconstruct_image = clip_0_1(reconstruct_image)
weights = tf.concat((weights,weights,weights),axis=-1)
weights = clip_0_1(weights)
h1_list = LossModel(input_image)
h2_list = LossModel(reconstruct_image)
rc_loss = 0.0
for h1, h2, weight in zip(h1_list, h2_list, selected_layer_weights):
h1 = K.batch_flatten(h1)
h2 = K.batch_flatten(h2)
rc_loss = rc_loss + weight * K.sum(K.square(h1 - h2), axis=-1)
pixel_loss = K.sum(K.square(K.batch_flatten(weights)*K.batch_flatten(input_image) - K.batch_flatten(weights)*K.batch_flatten(reconstruct_image)),axis=1)
return rc_loss+pixel_loss
权重输入仅用于训练时的损失函数。我设法训练了模型(使用 loss=None 进行编译),但它没有预测它应该预测的内容。看起来输入只是通过网络(没有任何修改)直接传递到输出。重建的输出图像看起来与输入图像完全一样。