使用 tf.estimator.DNNClassifier 调整超参数

问题描述 投票:0回答:1

我使用 DNNClassifier 类实现了以下模型。模型参数化如下

classifier = tf.estimator.DNNClassifier(
       hidden_units=[60, 30, 20],
       feature_columns=feature_columns,
       n_classes=len(labels),
       label_vocabulary=labels,
       batch_norm=True,
       optimizer=lambda: tf.keras.optimizers.Adam(
           learning_rate=tf.compat.v1.train.exponential_decay(
               learning_rate=0.1,
               global_step=tf.compat.v1.train.get_global_step(),
               decay_steps=10000,
               decay_rate=0.96)
       )
)

现在我想做一些超参数调整(例如学习率、单元数等)。

DNNClassifier
,是一个预制的估计器类,继承自
Estimator
类。

但是,虽然

Estimator
params
参数来传递超参数,但
DNNClassifier
却没有。

那么使用

DNNClassifier
进行超参数调整的首选方法应该是什么?

tensorflow2.0 tensorflow-estimator
1个回答
0
投票

首先,您的估算器需要一个输入函数,假设您使用 pandas 数据帧来保存数据,(data_df 和 label_df 是数据帧)您可以编写如下内容:

def make_input_fn(data_df, label_df, num_epochs=10, shuffle=True, batch_size=32):
    def input_function():
        ds = tf.data.Dataset.from_tensor_slices((dict(data_df), label_df))
        if shuffle:
            ds = ds.shuffle(1024)
        ds = ds.batch(batch_size).repeat(num_epochs)
        return ds

    return input_function

然后使用上面的代码创建两个输入函数,一个用于训练,一个用于验证,如下所示:

train_input_fn = make_input_fn(X_train, y_train)
val_input_fn = make_input_fn(X_val, y_val, num_epochs=1, shuffle=False)

最后训练您定义的分类器并使用验证集评估它。多次运行此管道以调整您的超参数

classifier = tf.estimator.DNNClassifier(
       hidden_units=[60, 30, 20],
       feature_columns=feature_columns,
       n_classes=len(labels),
       label_vocabulary=labels,
       batch_norm=True,
       optimizer=lambda: tf.keras.optimizers.Adam(
           learning_rate=tf.compat.v1.train.exponential_decay(
               learning_rate=0.1,
               global_step=tf.compat.v1.train.get_global_step(),
               decay_steps=10000,
               decay_rate=0.96)
       )
)
# Train Classifier.
classifier.train(train_input_fn)

# Evaluate Classifier.
result = classifier.evaluate(val_input_fn)
print(result)
© www.soinside.com 2019 - 2024. All rights reserved.