使用fit_generator批量训练模型

问题描述 投票:0回答:1

我的模型有100 000个图像训练样本,如何修改下面的代码以批量训练它?使用model.fit_generator我必须在generator函数内部指定它:

def data_generator(descriptions, features, n_step, max_sequence):
    # loop until we finish training
    while 1:
        # loop over photo identifiers in the dataset
        for i in range(0, len(descriptions), n_step):
            Ximages, XSeq, y = list(), list(),list()
            for j in range(i, min(len(descriptions), i+n_step)):
                image = features[j]
                # retrieve text input
                desc = descriptions[j]
                # generate input-output pairs
                in_img, in_seq, out_word = preprocess_data([desc], [image], max_sequence)
                for k in range(len(in_img)):
                    Ximages.append(in_img[k])
                    XSeq.append(in_seq[k])
                    y.append(out_word[k])
            # yield this batch of samples to the model
            yield [[array(Ximages), array(XSeq)], array(y)]

我的model.fit_generator代码:

model.fit_generator(data_generator(texts, train_features, 1, 150), 
                    steps_per_epoch=1500, epochs=50, callbacks=callbacks_list, verbose=1)

任何帮助都会很棒,我正在16GB V100 Tesla云云上进行培训

tensorflow machine-learning keras
1个回答
0
投票

生成器已经退回批次。

每个yield是一个批次。您完全可以根据自己的需要来设计批处理发生器。

在您的代码中,批处理大小为n_step

© www.soinside.com 2019 - 2024. All rights reserved.