为什么我的 ESRGAN Python 代码会产生棋盘格伪影(checkerboard artifacts)?

问题描述 投票:0回答:0

这是我的 ESRGAN 代码,它生成的图像带有棋盘格伪影,但我不知道原因:

def preprocess_vgg(x):
    """Rescale an image from [-1, 1] to [0, 255] and run ``preprocess_input`` on it.

    NumPy arrays are converted directly. Any other input is routed through a
    Keras ``Lambda`` layer that performs the same arithmetic with TensorFlow
    ops.

    :param x: image data; the [-1, 1] range is the documented expectation and
        is not checked here.
    :return: the output of ``preprocess_input`` for the rescaled data.
    """
    if not isinstance(x, np.ndarray):
        def _rescale_and_preprocess(tensor):
            # (t + 1) * 127.5 maps [-1, 1] onto [0, 255].
            return preprocess_input(tf.add(tensor, 1) * 127.5)

        return Lambda(_rescale_and_preprocess)(x)

    return preprocess_input((x + 1) * 127.5)
        
class VGG_LOSS(object):
    """Content (perceptual) loss computed on VGG19 feature maps.

    A frozen VGG19 (ImageNet weights, no classifier head) is cut after its
    first 20 layers and extended with a copy of layer 20 that is created
    without an activation argument, so features are read before that layer's
    activation.
    """

    def __init__(self, image_shape):
        """Build the frozen feature extractor.

        :param image_shape: value forwarded to ``VGG19(input_shape=...)``.
        """
        self.image_shape = image_shape
        self.VGG_i = 2
        self.VGG_j = 2
        self.vgg19 = VGG19(include_top=False, weights='imagenet', input_shape=self.image_shape)
        self.vgg19.trainable = False
        self.before_act_output = []

        # Freeze every VGG19 layer individually as well.
        for layer in self.vgg19.layers:
            layer.trainable = False

        # Disabled alternative kept for reference: block2_conv2 before activation.
        # self.block2_conv2_copy = Conv2D(filters=128, kernel_size=(3, 3), padding='same', name="block2_conv2")
        # self.block2_conv2_copy.trainable = False
        # self.model = Sequential(self.vgg19.layers[:5] + [self.block2_conv2_copy])
        # self.block2_conv2_copy.set_weights(self.vgg19.layers[5].get_weights())

        # Layer block5_conv4 before activation: reuse layers 0..19 and append
        # an activation-free convolution carrying layer 20's weights.
        backbone_layers = self.vgg19.layers[:20]
        weight_source = self.vgg19.layers[20]
        self.block5_conv4_copy = Conv2D(filters=512, kernel_size=(3, 3), padding='same', name="block5_conv4")
        self.block5_conv4_copy.trainable = False
        self.model = Sequential(backbone_layers + [self.block5_conv4_copy])
        self.block5_conv4_copy.set_weights(weight_source.get_weights())

        self.model.trainable = False

    def vgg_loss(self, y_true, y_pred):
        """Return the mean squared error between VGG features of two images.

        Both inputs go through ``preprocess_vgg`` and then the feature
        extractor; the features are divided by 12.75 before being compared.

        :param y_true: reference (high-resolution) image batch.
        :param y_pred: generated (super-resolved) image batch.
        :return: result of ``mean_squared_error`` on the scaled features.
        """
        reference_input = preprocess_vgg(y_true)
        generated_input = preprocess_vgg(y_pred)
        generated_features = self.model(generated_input) / 12.75
        reference_features = self.model(reference_input) / 12.75
        return mean_squared_error(generated_features, reference_features)
    
    
def get_optimizer():
    """Create the Adam optimizer used to compile the generator.

    :return: ``Adam`` with learning rate 1e-4, beta_1 0.9, beta_2 0.999 and
        epsilon 1e-8.
    """
    adam_settings = {
        "learning_rate": 1E-4,
        "beta_1": 0.9,
        "beta_2": 0.999,
        "epsilon": 1e-08,
    }
    return Adam(**adam_settings)

# Subpixel conv upsamples from (h, w, c) to (h*r, w*r, c/r^2)
def SubpixelConv2D(scale=2):
    """Build a Keras ``Lambda`` layer that performs sub-pixel upsampling.

    NOTE: TensorFlow backend only; relies on ``tf.nn.depth_to_space``.

    :param scale: block size handed to ``tf.nn.depth_to_space``. The declared
        output shape multiplies height and width by it and divides the channel
        count by ``scale ** 2``. Default=2
    :return: a ``Lambda`` layer wrapping the operation.
    """

    def _upscaled(dim):
        # Unknown (None) spatial dimensions stay unknown.
        return None if dim is None else dim * scale

    def subpixel_shape(input_shape):
        batch = input_shape[0]
        height = _upscaled(input_shape[1])
        width = _upscaled(input_shape[2])
        channels = int(input_shape[3] / (scale ** 2))
        return (batch, height, width, channels)

    def subpixel(x):
        return tf.nn.depth_to_space(x, scale)

    return Lambda(subpixel, output_shape=subpixel_shape)


# Convolution block with concatenated (dense) connections
def residual_block_gen(inp, ch=64, k_s=3, n_blocks=4):
    """Stack ``n_blocks`` Conv2D+PReLU stages whose outputs are concatenated.

    Each stage sees the concatenation of the block input and all previous
    stage outputs. A final convolution maps the accumulated features back to
    ``ch`` channels. No skip connection is added here; the caller does that.

    :param inp: input tensor.
    :param ch: number of filters of every convolution.
    :param k_s: kernel size of every convolution.
    :param n_blocks: number of Conv2D+PReLU stages before the final convolution.
    :return: output tensor of the final convolution.
    """
    layers = tf.keras.layers
    features = inp
    for _ in range(n_blocks):
        branch = layers.Conv2D(ch, k_s, padding='same')(features)
        branch = layers.PReLU(shared_axes=[1, 2])(branch)
        features = layers.concatenate([features, branch])

    return layers.Conv2D(ch, k_s, padding='same')(features)
    
    
def up_sampling_block(model, kernal_size, filters, strides):
    """Upsample a feature map with a convolution followed by sub-pixel shuffle.

    The pipeline is Conv2D -> SubpixelConv2D(2) -> PReLU.

    :param model: input tensor (feature map).
    :param kernal_size: kernel size of the convolution (parameter spelling kept
        so existing keyword callers keep working).
    :param filters: number of convolution filters; the sub-pixel step with
        scale 2 divides this channel count by 4.
    :param strides: stride of the convolution.
    :return: the upsampled tensor.
    """
    # Use the caller-supplied kernel size and stride; they were previously
    # accepted but replaced by the hard-coded values 3 and 1.
    model = Conv2D(filters, kernel_size=kernal_size, strides=strides, padding='same')(model)
    model = SubpixelConv2D(2)(model)

    model = PReLU(shared_axes=[1, 2])(model)

    return model

# Network architecture based on https://arxiv.org/pdf/1609.04802.pdf
# NOTE(review): the body below uses concatenating blocks with residual scaling
# ("ESRResnet"), so it does not look identical to that paper - confirm.
class Generator(object):
    """Factory for the super-resolution generator model."""

    def __init__(self, noise_shape = 256):
        """Remember the input shape later passed to ``Input(shape=...)``.

        :param noise_shape: shape of the low-resolution input.
        """
        self.noise_shape = noise_shape

    def generator(self, num_filters=32, num_res_blocks=10, res_block_scaling=None):
        """Build and return the generator ``Model``.

        Layout: 9x9 convolution + PReLU, five ``residual_block_gen`` blocks
        each added back with a scaling factor, a 3x3 convolution with batch
        normalization added to the stem output, one ``up_sampling_block`` and
        a final 9x9 convolution to 3 channels with ``tanh``.

        :param num_filters: filters of the stem, block and trunk convolutions.
        :param num_res_blocks: currently unused. Five blocks are always built;
            honouring this value would change the layer layout expected by
            weights saved from this model.
        :param res_block_scaling: factor applied to each block output before it
            is added to the running features. ``None`` selects 0.2.
        :return: ``Model`` mapping the low-resolution input to the output image.
        """
        ## ESRResnet
        # The argument used to be ignored in favour of a hard-coded 0.2; keep
        # 0.2 as the fallback so existing callers get the same network.
        residual_scaling = 0.2 if res_block_scaling is None else res_block_scaling

        input_lr = Input(shape=self.noise_shape)
        input_conv = Conv2D(num_filters, 9, padding='same')(input_lr)
        input_conv = PReLU(shared_axes=[1, 2])(input_conv)

        ESRRes = input_conv
        for _ in range(5):
            res_output = residual_block_gen(ESRRes, ch=num_filters)
            ESRRes = tf.keras.layers.Add()([ESRRes, res_output * residual_scaling])

        ESRRes = tf.keras.layers.Conv2D(num_filters, 3, padding='same')(ESRRes)
        ESRRes = tf.keras.layers.BatchNormalization()(ESRRes)
        # Long skip connection from the stem output.
        ESRRes = tf.keras.layers.Add()([ESRRes, input_conv])

        ESRRes = up_sampling_block(ESRRes, 3, 512, 1)
        #ESRRes=Upsample_block(ESRRes)

        output_sr = Conv2D(3, 9, activation='tanh', padding='same')(ESRRes)

        generator_model = Model(input_lr, output_sr)
        ## End of ESRResnet

        return generator_model
    
def discriminator_block(model, filters, kernel_size=3, strides=1,batchnorm=True,momentum=0.8):
    """Apply Conv2D, optional BatchNormalization and LeakyReLU(0.2).

    :param model: input tensor.
    :param filters: number of convolution filters.
    :param kernel_size: convolution kernel size.
    :param strides: convolution stride.
    :param batchnorm: when true, insert ``BatchNormalization`` after the
        convolution.
    :param momentum: momentum passed to ``BatchNormalization``.
    :return: output tensor of the LeakyReLU activation.
    """
    # Pass the kernel_size argument through; it used to be ignored in favour
    # of a hard-coded 3 (which is still the default).
    x = Conv2D(filters, kernel_size=kernel_size, strides=strides, padding='same')(model)
    if batchnorm:
        x = BatchNormalization(momentum=momentum)(x)
    return LeakyReLU(alpha=0.2)(x)

# Network architecture based on https://arxiv.org/pdf/1609.04802.pdf
# NOTE(review): the head below (global average pooling, dropout, single logit)
# may differ from that paper - confirm.
class Discriminator(object):
    """Factory for the discriminator model."""

    def __init__(self, image_shape):
        """Remember the input image shape.

        :param image_shape: shape passed to ``Input(shape=...)``.
        """
        self.image_shape = image_shape

    def discriminator(self, num_filters=32):
        """Build and return the discriminator ``Model``.

        Eight ``discriminator_block`` calls (widths 1x, 1x, 2x, 2x, 4x, 4x, 8x,
        8x of ``num_filters``; every second one with stride 2; the first one
        without batch normalization) feed a pooled dense head ending in a
        single unit with no activation.

        :param num_filters: base number of convolution filters.
        :return: ``Model`` mapping an image to one output value.
        """
        image_in = Input(shape=self.image_shape)

        features = discriminator_block(image_in, num_filters, batchnorm=False)
        features = discriminator_block(features, num_filters, strides=2)
        for width_factor in (2, 4, 8):
            width = num_filters * width_factor
            features = discriminator_block(features, width)
            features = discriminator_block(features, width, strides=2)

        ### FOR RAGAN
        head = GlobalAveragePooling2D()(features)
        head = Dropout(0.4)(head)
        head = Dense(num_filters * 16)(head)
        head = Dropout(0.2)(head)
        head = LeakyReLU(alpha=0.2)(head)
        logit = Dense(1)(head)

        return Model(image_in, logit)
        
# Image shape passed to VGG_LOSS and Discriminator by the training functions
# (presumably height, width, channels - confirm); dimtest is defined elsewhere.
image_shape = (dimtest,dimtest,3)
  
# Binary cross-entropy that treats its predictions as raw logits
# (from_logits=True); used by the two RAGAN_* loss functions.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def RAGAN_discriminator_loss(real_output, fake_output):
    """Discriminator loss on relative logits with randomly smoothed labels.

    The "real" term compares ``real_output - fake_output`` against targets
    drawn uniformly from [0.7, 1.2). The "fake" term compares
    ``fake_output - real_output`` against targets drawn from [0.0, 0.3).

    NOTE(review): the logits are differenced element by element and no batch
    mean is taken, despite the "RAGAN" name - confirm this is intended.

    :param real_output: discriminator outputs for real images.
    :param fake_output: discriminator outputs for generated images.
    :return: sum of the two cross-entropy terms.
    """
    n_real = len(real_output)
    real_targets = np.random.uniform(low=0.7, high=1.2, size=n_real).reshape((n_real, 1))
    real_term = cross_entropy(real_targets, real_output - fake_output)

    n_fake = len(fake_output)
    fake_targets = np.random.uniform(low=0.0, high=0.3, size=n_fake).reshape((n_fake, 1))
    fake_term = cross_entropy(fake_targets, fake_output - real_output)

    return real_term + fake_term

def RAGAN_generator_loss(real_output, fake_output):
    """Generator loss on relative logits with randomly smoothed labels.

    Compares ``fake_output - real_output`` against targets drawn uniformly
    from [0.8, 1.1).

    NOTE(review): the logits are differenced element by element and no batch
    mean is taken, despite the "RAGAN" name - confirm this is intended.

    :param real_output: discriminator outputs for real images.
    :param fake_output: discriminator outputs for generated images.
    :return: the cross-entropy value.
    """
    batch = len(fake_output)
    noisy_targets = np.random.uniform(low=0.8, high=1.1, size=batch).reshape((batch, 1))
    relative_logits = fake_output - real_output
    return cross_entropy(noisy_targets, relative_logits)
                       

def train_generator(epochs, batch_size, model_save_dir,x_train_lr,x_test_lr,x_train_hr,x_test_hr):
    """Pre-train the generator alone with an MAE loss and save its weights.

    The weights are written to
    ``model_save_dir + 'Proyecciongen_model_only_weights.h5'``. Because
    ``model_save_dir`` is used as a plain string prefix, it must already end
    with a path separator if it names a directory.

    :param epochs: maximum number of epochs (early stopping may end sooner).
    :param batch_size: batch size of the training and validation iterators.
    :param model_save_dir: string prefix of the weights file.
    :param x_train_lr: low-resolution training images (model inputs).
    :param x_test_lr: low-resolution validation images.
    :param x_train_hr: high-resolution training images (targets).
    :param x_test_hr: high-resolution validation images.
    :return: None.
    """
    # The original built an unused VGG_LOSS here (loading a full VGG19) and an
    # unused batch_count; both are dropped because this stage only uses MAE.
    lr_shape = (dimtrain, dimtrain, 3)
    generator = Generator(lr_shape).generator()
    generator.compile(loss=['mae'], optimizer=get_optimizer())
    #optimizer = tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries=[10000], values=[1e-4, 1e-5])

    train_datagen = tf.keras.preprocessing.image.ImageDataGenerator()
    val_datagen = tf.keras.preprocessing.image.ImageDataGenerator()

    train_datagen.fit(x_train_lr)
    val_datagen.fit(x_test_lr)

    # Named *_flow so the local no longer shadows this function's own name.
    train_flow = train_datagen.flow(x_train_lr, x_train_hr, batch_size=batch_size)
    val_flow = val_datagen.flow(x_test_lr, x_test_hr, batch_size=batch_size)

    # Stop when the validation loss has not improved for 10 epochs.
    early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10)

    generator.fit(
        train_flow,
        steps_per_epoch=len(x_train_hr) // batch_size,
        epochs=epochs,
        validation_data=val_flow,
        validation_steps=len(x_test_hr) // batch_size,
        shuffle=True,
        callbacks=[early_stopping],
    )  # callbacks=[mcp_save] ,callbacks=[TqdmCallback(verbose=0)]

    generator.save_weights(model_save_dir + 'Proyecciongen_model_only_weights.h5')

    # Drop our reference to the model and run the collector three times, as
    # the original did.
    del generator
    for _ in range(3):
        gc.collect()
    
# Default values for all parameters are given; if you want different values, pass them via the command line.
# For more info run: $ python train.py -h

def train(epochs, batch_size, model_save_dir,x_train_lr,x_test_lr,x_train_hr,x_test_hr):
    """Adversarially train the pre-trained generator against a discriminator.

    The generator is rebuilt and loaded from
    ``model_save_dir + 'Proyecciongen_model_only_weights.h5'``. Each step
    draws a random batch, computes the VGG content loss plus 0.001 times the
    adversarial loss for the generator and the adversarial loss for the
    discriminator, and updates both networks. After every epoch the mean
    losses are printed and a sample is plotted; every 50 epochs both models'
    weights are saved under ``model_save_dir``.

    :param epochs: number of epochs.
    :param batch_size: number of samples drawn per step.
    :param model_save_dir: string prefix for weight files (used by plain
        concatenation, so include the trailing separator).
    :param x_train_lr: low-resolution training images, indexable by an array.
    :param x_test_lr: unused here.
    :param x_train_hr: high-resolution training images, indexable by an array.
    :param x_test_hr: unused here.
    :raises ValueError: if the training set holds fewer samples than one batch.
    """
    batch_count = int(x_train_hr.shape[0] / batch_size)
    if batch_count < 1:
        # Without a single step the losses and sample images below would never
        # be defined; fail early with a clear message instead of a NameError.
        raise ValueError(
            "train() needs at least one full batch: got %d samples for batch_size=%d"
            % (x_train_hr.shape[0], batch_size)
        )

    loss = VGG_LOSS(image_shape)

    shape = (dimtrain,dimtrain,3)

    optimizer = get_optimizer()
    generator = Generator(shape).generator()
    generator.compile(loss=['mae'], optimizer=optimizer)
    # One throwaway step on a single sample before loading the weights (kept
    # from the original; presumably it makes sure the model is fully built).
    generator.train_on_batch(x_train_lr[:1], x_train_hr[:1])

    generator.load_weights(model_save_dir+'Proyecciongen_model_only_weights.h5')
    discriminator = Discriminator(image_shape).discriminator()

    # Both optimizers use 1e-4 for the first 10000 steps, then 1e-5.
    scheduledis = tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries=[10000], values=[1e-4, 1e-5])
    schedulegen = tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries=[10000], values=[1e-4, 1e-5])
    generator_optimizer = Adam(learning_rate=schedulegen)
    discriminator_optimizer = Adam(learning_rate=scheduledis)

    pls_metric = tf.keras.metrics.Mean()
    dls_metric = tf.keras.metrics.Mean()

    print("Entenamos ESRGan")
    for e in range(1, epochs+1):
        print ('-'*15, 'Epoch %d' % e, '-'*15)
        for step in tqdm(range(batch_count)):
            # Sample a batch with replacement.
            rand_nums = np.random.randint(0, x_train_hr.shape[0], size=batch_size)

            hr = x_train_hr[rand_nums]
            lr = x_train_lr[rand_nums]

            with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
                lr = tf.cast(lr, tf.float32)
                hr = tf.cast(hr, tf.float32)

                # Forward pass
                sr = generator(lr, training=True)
                hr_output = discriminator(hr, training=True)
                sr_output = discriminator(sr, training=True)

                # Compute losses
                con_loss = loss.vgg_loss(hr, sr)
                gen_loss = RAGAN_generator_loss(hr_output, sr_output)
                #gen_loss = generator_loss(hr_output,sr_output)

                ## 0.001 * con_loss for block2_conv2
                perc_loss = con_loss + 0.001 * gen_loss  #0.001 * con_loss if 2.2, con_loss if 5.4

                #disc_loss = discriminator_loss(hr_output, sr_output)
                disc_loss = RAGAN_discriminator_loss(hr_output, sr_output)

            # Compute gradient of perceptual loss w.r.t. generator weights
            gradients_of_generator = gen_tape.gradient(perc_loss, generator.trainable_variables)
            # Compute gradient of discriminator loss w.r.t. discriminator weights
            gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

            # Update weights of generator and discriminator
            generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
            discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

            # Accumulate the losses on every step so the epoch report is a
            # mean over all batches; previously only the last batch counted.
            pls_metric(perc_loss)
            dls_metric(disc_loss)

        print(f'{e}/{epochs}, perceptual loss = {pls_metric.result():.4f}, discriminator loss = {dls_metric.result():.4f}')
        pls_metric.reset_states()
        dls_metric.reset_states()

        # Show the first sample of the last batch: input, a fixed 64x64 crop of
        # the output (rows/cols 224..287), and the target, mapped to [0, 1].
        plt.figure(figsize=(20,20))
        plt.subplot(1,3,1)
        plt.imshow((lr[0] + 1) /2)
        plt.subplot(1,3,2)
        plt.imshow((sr[0][224:288,224:288] + 1) /2)
        plt.subplot(1,3,3)
        plt.imshow((hr[0] + 1) /2)
        plt.show()

        if e % 50 == 0:
            print("Save generator")
            generator.save_weights(model_save_dir + 'gen_model'+str(e)+'.h5')
            discriminator.save_weights(model_save_dir + 'dis_model'+str(e)+'.h5')

我更换过上采样方法(Conv2DTranspose 或 Conv2D),但结果相同。不知道是不是 VGG19 的问题。我也修改过损失函数。使用 VGG19 时,我取的是激活层之前的特征。训练分两步:先在“train_generator”中单独预训练生成器,然后在“train”中训练完整的 ESRGAN 模型。

请帮助我。

谢谢。

keras deep-learning artifacts vgg-net subpixel
© www.soinside.com 2019 - 2024. All rights reserved.