TensorFlow Serving: confusion about the data format in feature_configs

Question (2 votes, 1 answer)

I followed the TensorFlow Serving tutorial mnist_saved_model.py and tried to train and export a text-cnn-classifier model. The model pipeline is:

embedding layer -> cnn -> maxpool -> cnn -> dropout -> output layer

The original TensorFlow data input was:

data_in = tf.placeholder(tf.int32,[None, sequence_length] , name='data_in')

which I changed to:

  serialized_tf_example = tf.placeholder(tf.string, name='tf_example')
  feature_configs = {'x': tf.FixedLenFeature(shape=[sequence_length], 
                     dtype=tf.int64),}
  tf_example = tf.parse_example(serialized_tf_example, feature_configs)
  # use tf.identity() to assign name
  data_in = tf.identity(tf_example['x'], name='x')  
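To make the contract of this parsing code concrete, here is a plain-Python sketch (no TensorFlow involved) of what `tf.FixedLenFeature(shape=[sequence_length], dtype=tf.int64)` expects: every serialized example must carry exactly `sequence_length` integer values under the key `'x'`, and the parsed batch has shape `[batch, sequence_length]`. The function `parse_fixed_len_int64` is hypothetical and only mimics the shape and dtype check; it is not how TensorFlow implements `parse_example`.

```python
from typing import Dict, List, Sequence


def parse_fixed_len_int64(examples: Sequence[Dict[str, list]], key: str,
                          length: int) -> List[List[int]]:
    """Hypothetical stand-in for parse_example with one int64 FixedLenFeature.

    Each example is modelled as a dict mapping feature names to value lists.
    """
    batch = []
    for i, ex in enumerate(examples):
        values = ex[key]  # a missing key raises KeyError
        if len(values) != length:
            raise ValueError(
                f"example {i}: expected {length} values for {key!r}, got {len(values)}")
        if not all(isinstance(v, int) for v in values):
            # Strings (e.g. raw words) do not satisfy an int64 feature spec.
            raise TypeError(f"example {i}: {key!r} must hold integers, got {values!r}")
        batch.append(list(values))
    return batch  # shape: [len(examples), length]
```

In other words, whoever builds the examples (at training time and in the serving client) has to supply integer word indices, not the words themselves.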

This works during training, but at test time it fails with AbortionError(code=StatusCode.INVALID_ARGUMENT, details="Expects arg[0] to be int64 but string is provided").

I am confused by these lines:

 feature_configs = {'x': tf.FixedLenFeature(shape=[sequence_length], 
                    dtype=tf.int64),}

So I changed them to:

  feature_configs = {'x': tf.FixedLenFeature(shape=[sequence_length], 
                     dtype=tf.string),}

but then training fails with the following error:

Traceback (most recent call last):
  File "/serving/bazel-bin/tensorflow_serving/example/twitter-sentiment-cnn_saved_model.runfiles/tf_serving/tensorflow_serving/example/twitter-sentiment-cnn_saved_model.py", line 222, in <module>
    embedded_chars = tf.nn.embedding_lookup(W, data_in)
  File "/serving/bazel-bin/tensorflow_serving/example/twitter-sentiment-cnn_saved_model.runfiles/org_tensorflow/tensorflow/python/ops/embedding_ops.py", line 122, in embedding_lookup
    return maybe_normalize(_do_gather(params[0], ids, name=name))
  File "/serving/bazel-bin/tensorflow_serving/example/twitter-sentiment-cnn_saved_model.runfiles/org_tensorflow/tensorflow/python/ops/embedding_ops.py", line 42, in _do_gather
    return array_ops.gather(params, ids, name=name)
  File "/serving/bazel-bin/tensorflow_serving/example/twitter-sentiment-cnn_saved_model.runfiles/org_tensorflow/tensorflow/python/ops/gen_array_ops.py", line 1179, in gather
    validate_indices=validate_indices, name=name)
  File "/serving/bazel-bin/tensorflow_serving/example/twitter-sentiment-cnn_saved_model.runfiles/org_tensorflow/tensorflow/python/framework/op_def_library.py", line 589, in apply_op
    param_name=input_name)
  File "/serving/bazel-bin/tensorflow_serving/example/twitter-sentiment-cnn_saved_model.runfiles/org_tensorflow/tensorflow/python/framework/op_def_library.py", line 60, in _SatisfiesTypeConstraint
    ", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
TypeError: Value passed to parameter 'indices' has DataType string not in list of allowed values: int32, int64
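The traceback comes from `tf.nn.embedding_lookup`, which gathers rows of the embedding matrix `W` by integer index, so its `ids` must be int32 or int64. The following plain-Python analogue (the name `gather_rows` is hypothetical, and nested lists stand in for tensors) shows why a string cannot serve as an index:

```python
from typing import List, Sequence

Matrix = List[List[float]]


def gather_rows(W: Matrix, ids: Sequence[Sequence[int]]) -> List[Matrix]:
    """Pure-Python analogue of an embedding lookup: out[b][t] = W[ids[b][t]]."""
    for row in ids:
        for i in row:
            if not isinstance(i, int):
                # Mirrors the TypeError above: indices must be integers.
                raise TypeError(f"indices must be integers, got {type(i).__name__}")
    return [[W[i] for i in row] for row in ids]
```

With `W = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]`, `gather_rows(W, [[1, 2]])` returns the second and third rows, while `gather_rows(W, [["good"]])` raises `TypeError`. Switching the feature dtype to `tf.string` therefore only moves the type error from the serving client into the graph.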
Answer (0 votes)

Your code is wrong here:

serialized_tf_example = tf.placeholder(tf.string, name='tf_example')

This means your input is a string, for example the words of a sentence. Therefore:

feature_configs = {'x': tf.FixedLenFeature(shape=[sequence_length], 
                   dtype=tf.int64),}
tf_example = tf.parse_example(serialized_tf_example, feature_configs)

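To me this does not make sense, because you never convert the words from string to int. You need to load the vocabulary built from your training data to get the word indices!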
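A minimal sketch of that string-to-index step, assuming whitespace tokenization and two hypothetical conventions (id 0 for padding, id 1 for words not seen in training). `build_vocab` and `encode` are illustrative names, not part of the original training script; in practice you would reuse whatever vocabulary the training pipeline saved.

```python
from typing import Dict, Iterable, List

PAD_ID = 0  # hypothetical convention: pads sentences shorter than sequence_length
UNK_ID = 1  # hypothetical convention: words that were not in the training data


def build_vocab(sentences: Iterable[str]) -> Dict[str, int]:
    """Assign an integer id to every distinct word, in order of first appearance."""
    vocab = {"<pad>": PAD_ID, "<unk>": UNK_ID}
    for sentence in sentences:
        for word in sentence.split():
            if word not in vocab:
                vocab[word] = len(vocab)
    return vocab


def encode(sentence: str, vocab: Dict[str, int], sequence_length: int) -> List[int]:
    """Map words to ids, then truncate or pad to exactly sequence_length ids."""
    ids = [vocab.get(word, UNK_ID) for word in sentence.split()][:sequence_length]
    return ids + [PAD_ID] * (sequence_length - len(ids))
```

For example, with `vocab = build_vocab(["i love this", "i hate this"])`, `encode("i love tensorflow", vocab, 5)` gives `[2, 3, 1, 0, 0]`. A list like that is what belongs in the int64 feature `'x'` (for instance via `tf.train.Feature(int64_list=tf.train.Int64List(value=ids))`), and the client must use the same vocabulary that was used during training.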
