我正在使用 AWS Sagemaker 训练 Tensorflow 图像分类模型。在训练期间,我收到以下日志消息:
火车数据集的基数:1492 训练数据集中的类示例数量:{'Approved': 36, 'Rejected': 36} 验证数据集的基数:328 验证数据集中的类示例数量:{'Approved': 9, 'Rejected': 9}
我的理解是,类示例的数量应该相加才能给出基数。就我而言,我的训练数据集中有 36 个“已拒绝”和 1456 个“已批准”。然而,日志似乎暗示每个类仅存在 36 个示例。同样,在验证集中有 9 个“已拒绝”和 319 个“已批准”,但每个只能找到 9 个示例。
我想了解为什么示例之和不等于基数。我怀疑 Tensorflow 假设我有一个平衡的数据集,并且我需要以某种方式指定这些类是不平衡的。是这种情况吗?如果是的话,我如何将类别权重应用于我的估计器。
供参考,我的估计量定义如下:
ic_estimator = Estimator(
role=role_arn,
image_uri=train_image_uri,
source_dir=train_source_uri,
model_uri=train_model_uri,
entry_point="transfer_learning.py",
instance_count=1,
input_mode='FastFile',
instance_type=train_instance_type,
max_run=3600,
hyperparameters=hyperparameters,
binary_mode=True,
objective_metric_name='validation:recall', # Sets the metric to determine the best model by
metric_definitions=[{'Name': 'validation:recall', 'Regex': "val_recall: (\d+\.\d+)"}],
output_path=s3_output_path,
base_job_name=training_job_name)
由于您在“批准”和“拒绝”之间看到相同的数字,我建议检查您的训练脚本,并确保您将正确的数据传递给 Tensorflow。
另外,请尝试查找谁在打印“火车数据集的基数:”日志。我简单地检查了Tensorflow的源代码,但没有找到该日志行。这个日志打印代码本身可能有问题。