`output_signature` 必须包含属于 `tf.TypeSpec` 子类的对象,但发现 <class 'list'> 不是

问题描述 投票:0回答:1

发生错误代码

hist = model.fit(
            data_gen_train.generate(),
            steps_per_epoch=2 if params['quick_test'] else data_gen_train.get_total_batches_in_data(),
            validation_data=data_gen_test.generate(),
            validation_steps=2 if params['quick_test'] else data_gen_test.get_total_batches_in_data(),
            epochs=1,
            verbose=0
        )

完整的错误日志

File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 123, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/tensorflow/python/data/ops/from_generator_op.py", line 124, in _from_generator
    raise TypeError(f"`output_signature` must contain objects that are "
TypeError: `output_signature` must contain objects that are subclass of `tf.TypeSpec` but found <class 'list'> which is not.

现在我正在本地计算机上练习声音事件检测论文。 但是过去使用的是 keras 和 Tensorflow,我在训练时遇到了一些问题。

模型是用 keras 模型制作的,并编译了我的以下代码。

model = Model(inputs=spec_start, outputs=[sed, doa])
model.compile(optimizer=Adam(), loss=['binary_crossentropy', 'mse'], loss_weights=weights)

从错误代码来看,可能存在于哪里? .fit() 方法中的output_signature 是什么

python tensorflow keras
1个回答
0
投票

我使用的是tensorflow 2.16,在返回列表时遇到了这个问题,如下所示:

def __getitem__(self, index):
    X1, X2, y = get_data()
    return [X1, X2], y

为了解决这个问题,我将列表更改为元组,问题解决了:

def __getitem__(self, index):
    X1, X2, y = get_data()
    return (X1, X2), y
© www.soinside.com 2019 - 2024. All rights reserved.