在 fit_generator 中,完成所有 steps_per_epoch 之后 on_epoch_end 没有被调用

问题描述 投票:0回答:1

我已经构建了自己的DataGenerator,以便在keras.fit_generator中使用它

在我的训练脚本中,我用 2 组路径列表实例化了 2 个生成器:一个用作训练生成器,另一个用作验证生成器。on_epoch_end 只会在训练生成器上被调用,而不会在验证生成器上被调用。我需要 on_epoch_end 回调来重置体数据(volume)索引,否则在第二个 epoch 中加载体数据时会报错:IndexError: list index out of range

# Generator feeding the training set: shuffled cube order, augmentation on.
training_generator = DataGenerator.DataGenerator(
    name='TrainingLoader',
    list_id=list_id,
    mask_id=mask_id,
    n_cube=n_cube_train,
    batch_size=2,
    dim=(64, 64, 64),
    n_channels=1,
    n_classes=3,
    shuffle=True,
    augmentation=True,
    overlap=4,
    rotation=0,
    translation=0,
    scaling=1,
    channel_first=False,
    depth_first=False,
)

# Generator feeding the validation set: same geometry, but no shuffling
# and no augmentation.
validation_generator = DataGenerator.DataGenerator(
    name='ValidLoader',
    list_id=valid_list_id,
    mask_id=valid_mask_list_id,
    n_cube=n_cube_valid,
    batch_size=2,
    dim=(64, 64, 64),
    n_channels=1,
    n_classes=3,
    shuffle=False,
    augmentation=False,
    overlap=4,
    rotation=0,
    translation=0,
    scaling=1,
    channel_first=False,
    depth_first=False,
)

# No steps_per_epoch is passed; validation runs after every epoch
# (validation_freq=1).
model3.fit_generator(
    generator=training_generator,
    epochs=1000,
    validation_data=validation_generator,
    validation_freq=1,
    verbose=1,
    shuffle=False,
    workers=0,
    callbacks=callback,
)

============================================================================

class DataGenerator(keras.utils.Sequence):
    'Generates data for Keras'

def __init__(self, name, list_id, mask_id, n_cube, batch_size=5, dim=(64, 64, 64), n_channels=1,
             n_classes=10, shuffle=False, augmentation=False, overlap=4, rotation=10, translation=10, scaling=0.9,
             channel_first=False, depth_first=False):
    """
    Store the generator configuration and rewind the read cursor.
    ---
    :param name: Label for this generator instance (stored as-is)
    :param list_id: Indexable collection with one entry per volume; each entry is
        handed to dt.get_process_volume as data_dir
    :param mask_id: Indexable collection parallel to list_id; each entry is handed
        to dt.get_process_volume as mask_dir
    :param n_cube: Number of cubes served per epoch; __len__ derives the batch count from it
    :param batch_size: Number of cubes to load per batch
    :param dim: Dimension of the data; dim[0] is used as the cube width/resolution
        and dim[2] as the cube depth
    :param n_channels: Number of information per pixel. 1-Grayscale 3-RGB
    :param n_classes: Number of mask classes, forwarded to dt.get_cube_mask
    :param shuffle: Boolean forwarded to dt.get_cube_index when a volume is loaded
    :param augmentation: Boolean forwarded to dt.get_process_volume
    :param overlap: Overlap value forwarded to dt.get_process_volume and dt.get_cube_index
    :param rotation: Rotation setting forwarded to dt.get_process_volume
    :param translation: Translation setting forwarded to dt.get_process_volume
    :param scaling: Scaling setting forwarded to dt.get_process_volume
    :param channel_first: Stored on the instance
    :param depth_first: Stored on the instance
    """
    # Static configuration.
    self.name = name
    self.list_id = list_id
    self.mask_id = mask_id
    self.n_cube = n_cube
    self.batch_size = batch_size
    self.dim = dim
    self.overlap = overlap
    self.n_channels = n_channels
    self.n_classes = n_classes
    self.shuffle = shuffle
    self.augmentation = augmentation
    self.rotation = rotation
    self.translation = translation
    self.scaling = scaling
    self.channel_first = channel_first
    self.depth_first = depth_first

    # Read cursor: which volume comes next, which cube of the loaded volume
    # comes next, and the cube index list of the loaded volume (empty means
    # that nothing is loaded yet).
    self.offset = 0
    self.volume_index = 0
    self.cube_index = 0
    self.volume_cube_index = []

    # Run the epoch hook last so that it always sees a fully initialised
    # object and nothing it sets is overwritten afterwards.
    self.on_epoch_end()

def __len__(self):
    """
    Number of batches that make up one epoch.
    ---
    :return: n_cube divided by batch_size, rounded up, so a final partial
             batch is counted as a full one
    """
    total_cubes = self.n_cube
    per_batch = self.batch_size

    # Integer ceiling division: adding (per_batch - 1) pushes any remainder
    # over into the next whole batch.
    return (total_cubes + per_batch - 1) // per_batch

def __getitem__(self, index):
    """
    Return the next batch from the generator's internal cursor.
    ---
    :param index: Batch position requested by the caller; it is not used,
                  batches are produced in cursor order
    :return: The value returned by __data_generation for the next batch
    """
    batch = self.__data_generation()
    return batch

def on_epoch_end(self):
    """
    Rewind the read cursor so the next batch starts again at the first volume.
    Called from __init__ and intended to be called at the end of every epoch.
    This method itself does not shuffle anything.
    ---
    :return:  None
    """
    self.offset = 0
    self.volume_index = 0
    self.cube_index = 0
    # Drop the cube index list of the volume that is still loaded. With an
    # empty list the "cursor at end of list" check in __data_generation is
    # true on the next batch, which forces a reload starting at volume 0;
    # otherwise the stale volume would be replayed from its first cube.
    self.volume_cube_index = []

def __data_generation(self):
    """
    Assemble one batch of batch_size cubes and their masks, advancing the
    read cursor by one cube per entry.
    ---
    :return: tuple (cubes, masks), each built by concatenating the per-cube
             results along axis 0
    """
    for slot in range(self.batch_size):
        # A cursor sitting at the end of the index list means the loaded
        # volume is used up (or nothing is loaded yet): fetch the next one.
        if len(self.volume_cube_index) == self.cube_index:
            self.load_volume()

        position = self.volume_cube_index[self.cube_index]
        new_cube = dt.get_cube(self.volume, position, self.dim[0], self.dim[2])
        new_mask = dt.get_cube_mask(self.mask, position, self.dim[0], self.dim[2], self.n_classes)

        if slot == 0:
            cubes, masks = new_cube, new_mask
        else:
            # Newer entries are placed in front of the ones gathered so far,
            # so the batch ends up in reverse read order.
            cubes = np.concatenate((new_cube, cubes), axis=0)
            masks = np.concatenate((new_mask, masks), axis=0)

        self.cube_index += 1

    return cubes, masks

def load_volume(self):
    """
    Load the next volume and its mask, build the list of cube positions for
    it, and advance the volume cursor.

    The volume cursor wraps around to the first volume once every entry of
    list_id has been used. This keeps the generator usable when the cursor
    is not rewound between epochs, or when the number of requested batches
    times batch_size exceeds the number of available cubes; both cases
    previously ended in "IndexError: list index out of range".
    ---
    :return: None
    :raises IndexError: if list_id is empty
    """
    n_volumes = len(self.list_id)
    if n_volumes == 0:
        # Same exception type an empty list produced before the wrap-around.
        raise IndexError("DataGenerator has no volumes to load (list_id is empty)")

    # Map the running cursor onto a valid position in list_id / mask_id.
    position = self.volume_index % n_volumes

    # NOTE: 'kernel_widht' is the keyword name used by the dt helper; keep the spelling.
    self.volume, self.mask = dt.get_process_volume(data_dir=self.list_id[position],
                                                   mask_dir=self.mask_id[position],
                                                   kernel_widht=self.dim[0],
                                                   kernel_depth=self.dim[2],
                                                   overlap=self.overlap,
                                                   rotation=self.rotation,
                                                   translation=self.translation,
                                                   scaling=self.scaling,
                                                   augmentation=self.augmentation)

    self.volume_cube_index = dt.get_cube_index(image=self.volume,
                                               resolution=self.dim[0],
                                               depth=self.dim[2],
                                               overlap=self.overlap,
                                               shuffle=self.shuffle)

    # Restart at the first cube of the new volume and point at the next volume.
    self.cube_index = 0
    self.volume_index = position + 1
python tensorflow keras deep-learning training-data
1个回答
0
投票

我的体数据(volume)总数不能被批量大小整除,因此没有触发 on_epoch_end 调用:体数据的数量是奇数,而批量大小为 2。

© www.soinside.com 2019 - 2024. All rights reserved.