如何在张量流数据集上使用样本权重?

问题描述 投票:0回答:1

我一直在使用 Tensorflow 和 Tensorflow 数据集在 python 中训练用于多类语义分割的unet模型。

我注意到我的一个班级在培训中的代表性似乎不足。经过一些研究后,我发现了样本权重,并认为这可能是解决我的问题的一个好方法,但我一直无法破译有关如何使用它的文档或找到正在使用它的示例。

有人可以帮助解释样本权重如何与训练数据集一起发挥作用,或者向我指出一个正在实施的示例吗?或者甚至 model.fit 函数期望什么类型的输入也会有帮助。

tensorflow machine-learning keras tensorflow-datasets tf.keras
1个回答
15
投票

来自 tf.keras model.fit()

文档

sample_weight

[...] 当 x 是数据集、生成器或

keras.utils.Sequence
实例时,不支持此参数,而是提供样本权重作为 x 的第三个元素。

这是什么意思?这在官方

文档图例之一中的 
Dataset 案例中得到了证明:

sample_weight = np.ones(shape=(len(y_train),)) sample_weight[y_train == 5] = 2.0 # Create a Dataset that includes sample weights # (3rd element in the return tuple). train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train, sample_weight)) # Shuffle and slice the dataset. train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64) model = get_compiled_model() model.fit(train_dataset, epochs=1)
请参阅链接以获取完整的示例。

© www.soinside.com 2019 - 2024. All rights reserved.