方法`export_savedmodel`的参数`serving_input_fn`是什么?

问题描述 投票:0回答:1

我正在尝试训练Char RNN,并在训练后导出/保存模型,以便可以在推理时使用它。这是模型:

def char_rnn_model(features, target):
    """Character level recurrent neural network model to predict classes."""
    target = tf.one_hot(target, 15, 1, 0)
    #byte_list = tf.one_hot(features, 256, 1, 0)
    byte_list = tf.cast(tf.one_hot(features, 256, 1, 0), dtype=tf.float32)
    byte_list = tf.unstack(byte_list, axis=1)

    cell = tf.contrib.rnn.GRUCell(HIDDEN_SIZE)
    _, encoding = tf.contrib.rnn.static_rnn(cell, byte_list, dtype=tf.float32)


    logits = tf.contrib.layers.fully_connected(encoding, 15, activation_fn=None)
    #loss = tf.contrib.losses.softmax_cross_entropy(logits, target)
    loss = tf.contrib.losses.softmax_cross_entropy(logits=logits, onehot_labels=target)

    train_op = tf.contrib.layers.optimize_loss(
      loss,
      tf.contrib.framework.get_global_step(),
      optimizer='Adam',
      learning_rate=0.001)

    return ({
      'class': tf.argmax(logits, 1),
      'prob': tf.nn.softmax(logits)
    }, loss, train_op)

和培训部分:

# train
model_dir = "model"
classifier = learn.Estimator(model_fn=char_rnn_model,model_dir=model_dir)
count=0
n_epoch = 20
while count<n_epoch:
        print("\nEPOCH " + str(count))
        classifier.fit(x_train, y_train, steps=1000,batch_size=10)
        y_predicted = [
              p['class'] for p in classifier.predict(
              x_test, as_iterable=True,batch_size=10)
        ]
        score = metrics.accuracy_score(y_test, y_predicted)
        print('Accuracy: {0:f}'.format(score))
        count+=1

([x_train是一个uint8形状的数组(16639,100))

Tensorflow文档介绍了似乎可以执行我想要的方法的export_savedmodel。但我不理解第二个参数serving_input_fn。应该是什么? classifier.export_savedmodel(output_dir, ???)

我正在使用Tensorflow 1.8.0和python 2.7.14。

这与this thread有关。

python tensorflow machine-learning inference
1个回答
0
投票

[尝试查看是否有Tensorflow文档所说的model.save方法。否则,您可以只保存权重,然后将其重新加载到默认模型中。您可以写:

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Create a new model instance
model = create_model()

# Restore the weights
model.load_weights('./checkpoints/my_checkpoint')
© www.soinside.com 2019 - 2024. All rights reserved.