Can I change class_weight during training?


I want to change my class_weight during Keras training.

I'm using fit_generator together with a Callback, as shown below.

model.fit_generator(
                decoder_generator(x_train, y_train),
                steps_per_epoch=len(x_train),
                epochs=args.epochs,
                validation_data=decoder_generator(x_valid, y_valid),
                validation_steps=len(x_valid),
                callbacks=callback_list,
                class_weight=class_weights,
                verbose=1)

class Valid_checker(keras.callbacks.Callback):
    def __init__(self, model_name, patience, val_data, x_length):
        super().__init__()
        self.best_score = 0
        self.patience = patience
        self.current_patience = 0 
        self.model_name = model_name
        self.validation_data = val_data
        self.x_length = x_length


    def on_epoch_end(self, epoch, logs=None):
        X_val, y_val = self.validation_data
        # use the model Keras attached to this callback, not a global variable
        y_predict, x_predict = self.model.predict_generator(no_decoder_generator(X_val, y_val), steps=len(X_val))
        y_predict = np.asarray(y_predict)
        x_predict = np.asarray(x_predict)  

decoder_generator and no_decoder_generator are just custom generators.

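I want to change the class weights at the end of every epoch. Is that possible? If so, how?

My data is imbalanced, and the model keeps overfitting to one class.

At the end of each epoch, I want to compute the accuracy of each class and increase the weight of the classes with low accuracy.

How can I do this?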
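One way to turn that idea into numbers, as a sketch: compute each class's accuracy on the validation set, then give lower-accuracy classes larger weights. The function names `per_class_accuracy` and `weights_from_accuracy` and the weighting rule (`2 - accuracy`, normalised to a mean of 1) are hypothetical choices for illustration, not part of Keras; the inputs are assumed to be 1-D arrays of integer class indices.

```python
import numpy as np

def per_class_accuracy(y_true, y_pred, num_classes):
    """Fraction of samples of each true class that were predicted correctly.

    y_true, y_pred: 1-D integer arrays of class indices.
    A class with no samples gets accuracy 1.0, so its weight is not boosted.
    """
    acc = np.ones(num_classes, dtype=float)
    for c in range(num_classes):
        mask = y_true == c
        if mask.any():
            acc[c] = np.mean(y_pred[mask] == c)
    return acc

def weights_from_accuracy(acc):
    """Hypothetical rule: weight = 2 - accuracy, rescaled so the weights average 1.

    Returns a {class_index: weight} dict, the shape class_weight expects.
    """
    raw = 2.0 - np.asarray(acc, dtype=float)
    raw = raw / raw.mean()
    return {c: float(w) for c, w in enumerate(raw)}
```

With a softmax model, the predicted indices could come from `np.argmax(probs, axis=1)` and the true ones from `np.argmax(y_val, axis=1)` for one-hot labels. A linear rule keeps the weights bounded; a `1 / accuracy` rule would blow up for classes whose accuracy is close to zero.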

python tensorflow keras deep-learning weight
1 Answer

How about simply looping over the epochs one at a time?

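for i in range(args.epochs):
    # calculate_weights() stands for your own function returning {class_index: weight}
    class_weights = calculate_weights()
    model.fit_generator(
        decoder_generator(x_train, y_train),
        steps_per_epoch=len(x_train),
        epochs=1,
        validation_data=decoder_generator(x_valid, y_valid),
        validation_steps=len(x_valid),
        callbacks=callback_list,
        class_weight=class_weights,
        verbose=1)
    # a callback in callback_list may have set this to request early stopping
    if model.stop_training:
        break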

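Within a single fit_generator call there is no direct way to use different class weights for each epoch. You can still incorporate early stopping by checking the value of model.stop_training after each call.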
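The question's `Valid_checker` already stores `best_score`, `patience` and `current_patience`. A minimal sketch of that patience logic as a plain-Python helper, assuming higher scores are better; the class name `PatienceTracker` is hypothetical. Inside a callback's `on_epoch_end`, a `True` result would become `self.model.stop_training = True`, which the per-epoch loop then checks.

```python
class PatienceTracker:
    """Hypothetical helper: signal a stop once the score has not improved
    for `patience` consecutive epochs (higher score = better)."""

    def __init__(self, patience):
        self.patience = patience
        self.best_score = float("-inf")
        self.wait = 0  # epochs since the last improvement

    def update(self, score):
        """Record this epoch's score; return True if training should stop."""
        if score > self.best_score:
            self.best_score = score
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


tracker = PatienceTracker(patience=3)
for epoch, score in enumerate([0.5, 0.6, 0.6, 0.55, 0.58]):
    if tracker.update(score):
        print("stopping after epoch", epoch)  # prints: stopping after epoch 4
        break
```

Because the loop calls fit_generator repeatedly, the tracker (or the callback holding it) has to be created once, before the loop; a fresh instance per call would reset the counter every epoch.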

Sample

import numpy as np
from keras.layers import Input, Dense
from keras.models import Model
from keras.callbacks import Callback

class Valid_checker(Callback):
    """Stops training after 8 epochs in total, counted across fit_generator calls."""
    def __init__(self):
        super().__init__()
        self.n_epoch = 0  # Keras attaches self.model itself via set_model()

    def on_epoch_end(self, epoch, logs=None):
        self.n_epoch += 1
        if self.n_epoch == 8:
            self.model.stop_training = True

def decoder_generator():
    while True:
        for i in range(10):
            # random inputs and random one-hot targets for 3 classes
            yield np.random.rand(10, 5), np.eye(3)[np.random.randint(3, size=10)]


inputs = Input(shape=(5,))
outputs = Dense(3, activation='softmax')(inputs)
model = Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='rmsprop',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# create the callback once so its epoch counter persists across fit_generator calls
checker = Valid_checker()
for i in range(10):
    # the class weights could be recomputed here before every epoch
    model.fit_generator(generator=decoder_generator(),
                        class_weight={0: 1/3, 1: 1/3, 2: 1/3},
                        steps_per_epoch=10,
                        epochs=1,
                        callbacks=[checker])
    if model.stop_training:
        break
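If restarting fit_generator every epoch is undesirable, another option is to move the weighting into the generator itself. This sketch assumes standalone Keras 2.x, where fit_generator accepts generators that yield `(inputs, targets, sample_weights)` tuples, and one-hot targets. The name `weighted_generator` is hypothetical; `class_weights` is an ordinary dict that a callback's `on_epoch_end` could update in place (for example with `class_weights.update(new_weights)`), and the wrapped generator would be passed to fit_generator in place of the `class_weight` argument.

```python
import numpy as np

def weighted_generator(base_gen, class_weights):
    """Wrap a generator of (x, y_onehot) batches so each batch also carries
    per-sample weights looked up from class_weights when the batch is produced.

    class_weights is a dict {class_index: weight}; mutating it in place
    (e.g. from a callback) changes the weights of batches produced afterwards.
    """
    while True:
        x, y = next(base_gen)
        labels = np.argmax(y, axis=1)  # one-hot targets -> class indices
        sample_weight = np.array([class_weights[int(c)] for c in labels], dtype=float)
        yield x, y, sample_weight
```

Two caveats: fit_generator prefetches batches into a queue (`max_queue_size`, 10 by default), so a few batches after an update may still carry the old weights; and with `use_multiprocessing=True` the generator runs in another process and will not see in-place changes to the dict.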
© www.soinside.com 2019 - 2024. All rights reserved.