TensorFlow: "metric has not yet been built" error


I'm having trouble training a CNN model with TensorFlow for a multi-class classification task. When I try to fit the model, I get the following traceback:

Traceback (most recent call last):
  File "/app/main.py", line 182, in <module>
    model.fit(
  File "/app/models/cnn.py", line 55, in fit
    super().fit(train, validation, num_epochs, steps)
  File "/app/models/neural_network.py", line 85, in fit
    self._history = self._model.fit(
                    ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py", line 123, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/usr/local/lib/python3.11/dist-packages/keras/src/trainers/compile_utils.py", line 357, in result
    raise ValueError(
ValueError: Cannot get result() since the metric has not yet been built.
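For context, the last frame is `result()` in `keras/src/trainers/compile_utils.py`. As far as I can tell, Keras 3's compiled-metrics container there is built lazily on its first `update_state()` call, i.e. when the first batch of `(y_true, y_pred)` arrives, so a training or validation loop that finishes without processing a single batch would hit this `ValueError`. The toy class below is a hypothetical pure-Python stand-in, not the Keras API, that only mimics that build-on-first-update behavior:

```python
class ToyCompiledMetrics:
    """Hypothetical stand-in for Keras's compiled-metrics container (not the real API).

    It builds itself on the first update_state() call and refuses to report
    a result before that, which is the behavior behind the error message.
    """

    def __init__(self):
        self.built = False
        self._correct = 0
        self._total = 0

    def update_state(self, y_true, y_pred):
        # Building is deferred until the first batch arrives.
        if not self.built:
            self.built = True
        self._correct += sum(int(t == p) for t, p in zip(y_true, y_pred))
        self._total += len(y_true)

    def result(self):
        if not self.built:
            raise ValueError("Cannot get result() since the metric has not yet been built.")
        return self._correct / self._total if self._total else 0.0


def run_epoch(batches):
    """Feed each (y_true, y_pred) batch to the metrics, then read the result."""
    metrics = ToyCompiledMetrics()
    for y_true, y_pred in batches:
        metrics.update_state(y_true, y_pred)
    return metrics.result()


print(run_epoch([([0, 1], [0, 0])]))  # 0.5
try:
    run_epoch([])  # zero batches: the metrics are never built
except ValueError as err:
    print(err)
```

With at least one batch the result is available; with an empty iterable the loop body never runs, so `result()` raises the same message as in the traceback.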

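The error seems to be related to a metric that has not been built before the model is fitted. Can someone explain why this error occurs and how to fix it? I searched the TensorFlow documentation but found no hints on how to avoid it. The code that produces the error is as follows: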

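import os
from typing import Optional

import keras
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from keras.optimizers import Adam

# Logger, Metrics and Model are defined elsewhere in the project (not shown in this question).


class NeuralNetwork: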

    def __init__(self, name: str, classes, shape, batch_size: int = 32, logger: Optional[Logger] = None):

        # Define model parameters
        self._classes = classes
        self._shape = shape
        self._batch_size = batch_size

        # Define model name and output
        self._name = name
        self._output = 'results/{}.keras'.format(self._name)

        # Define logger
        if logger is None:
            self._logger = Logger(name=self._name, level='INFO').logger
        else:
            self._logger = logger.logger

        self._already_exists = False
        self._history = None
        self._metrics = Metrics(name=self._name, output_dir=self._output)

        # Load model if exists
        if os.path.exists(self._output):
            self._model = keras.models.load_model(self._output)
            self._already_exists = True
        else:
            self._model = Model()

    def fit(self, train, validation, num_epochs, steps: list):

        # Compile model
        optimizer = Adam(learning_rate=0.001)
        self._model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['categorical_accuracy'])

        # Callbacks definition
        reduce_lr = ReduceLROnPlateau(monitor='val_categorical_accuracy', mode='max', factor=0.1,
                                      min_lr=1e-7, patience=5, verbose=1)

        early_stop = EarlyStopping(monitor='val_categorical_accuracy', mode='max', patience=10,
                                   verbose=1, restore_best_weights=True)

        model_checkpoint = ModelCheckpoint(filepath=self._output, monitor='val_categorical_accuracy',
                                           save_best_only=True, mode='max')

        # Fit model
        self._history = self._model.fit(
            train,
            steps_per_epoch=steps[0],  
            epochs=num_epochs,
            validation_data=validation,
            validation_steps=steps[1],  
            verbose=2,
            callbacks=[
                reduce_lr,
                early_stop,
                model_checkpoint
            ]
        )
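The question doesn't show how `steps` is computed. If it comes from floor division such as `len(df) // batch_size` (an assumption), a split smaller than one batch gives a step count of 0. A step count of 0, or a dataset that yields no batches, is one plausible way to end up with metrics that are never built, so it is worth ruling out. A minimal sketch of a hypothetical helper `compute_steps` that rounds up, matching the `drop_remainder=False` default of `Dataset.batch()`, and refuses to return 0:

```python
def compute_steps(num_samples: int, batch_size: int) -> int:
    """Number of batches needed to see every sample once (hypothetical helper).

    Rounds up, so a final partial batch still counts as a step; this matches
    tf.data's Dataset.batch() default of drop_remainder=False.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if num_samples < 0:
        raise ValueError(f"num_samples must be non-negative, got {num_samples}")
    # Integer ceiling division avoids floating-point rounding.
    steps = (num_samples + batch_size - 1) // batch_size
    if steps == 0:
        # An empty split gives fit() zero batches for this loop,
        # so its compiled metrics would never be built.
        raise ValueError("num_samples is 0, so the split yields no batches")
    return steps


# Floor division silently drops a validation split smaller than one batch:
print(10 // 32)                # 0
print(compute_steps(10, 32))   # 1
print(compute_steps(100, 32))  # 4
```

For example, `steps = [compute_steps(n_train, batch_size), compute_steps(n_val, batch_size)]`, where `n_train` and `n_val` are hypothetical row counts of the training and validation dataframes. Since both of those datasets are repeated in the pipeline, rounding up never runs past the end of the data.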

For dataset creation, I follow a fairly standard procedure:

def dataset_creation(...):

    ...

    for i, dataset in enumerate([self._train, self._validation, self._test]):

        # Extract the data indexes from the dataframe
        indexes = [str(index) for index in dataset.index]

        # Extract the per-class label columns from the dataframe
        labels = dataset[[''.join(['label_', str(s).lower()]) for s in self._encoder.classes_]].astype(int)

        # Create the dataset of (index, label) pairs
        d = tf.data.Dataset.from_tensor_slices((indexes, labels))

        # Apply the loader function to each element of the dataset
        d = d.map(lambda index, label: (
            tf.numpy_function(self.loader, [index, input_size], tf.float32),
            label
        ), num_parallel_calls=self._cores)

        # Repeat only the training and validation datasets
        if i in [0, 1]:
            d = d.repeat()

        # Shuffle and batch the dataset
        if shuffle:
            d = d.shuffle(len(indexes))

        d = d.batch(batch_size=batch_size)
        d = d.prefetch(buffer_size=1)

    ...

Any insights or suggestions on how to fix this would be greatly appreciated.

Thanks!

python tensorflow deep-learning
1 Answer

Did you ever get an answer? I have the same problem and couldn't find an answer on the platform.

Thanks
