我一直在使用 Tensorflow 和 Tensorflow 数据集在 python 中训练用于多类语义分割的unet模型。
我注意到我的一个班级在培训中的代表性似乎不足。经过一些研究后,我发现了样本权重,并认为这可能是解决我的问题的一个好方法,但我一直无法破译有关如何使用它的文档或找到正在使用它的示例。
有人可以帮助解释样本权重如何与训练数据集一起发挥作用,或者向我指出一个正在实施的示例吗?或者甚至 model.fit 函数期望什么类型的输入也会有帮助。
来自 tf.keras model.fit()
的
文档:
sample_weight
[...] 当 x 是数据集、生成器或
实例时,不支持此参数,而是提供样本权重作为 x 的第三个元素。keras.utils.Sequence
这是什么意思?这在官方
文档图例之一中的
Dataset
案例中得到了证明:
sample_weight = np.ones(shape=(len(y_train),))
sample_weight[y_train == 5] = 2.0
# Create a Dataset that includes sample weights
# (3rd element in the return tuple).
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train, sample_weight))
# Shuffle and slice the dataset.
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)
model = get_compiled_model()
model.fit(train_dataset, epochs=1)
请参阅链接以获取完整的示例。