I followed the TensorFlow Serving tutorial mnist_saved_model.py and tried to train and export a text-cnn-classifier model. The pipeline is:
embedding layer -> cnn -> maxpool -> cnn -> dropout -> output layer
The TensorFlow data input:
data_in = tf.placeholder(tf.int32,[None, sequence_length] , name='data_in')
was changed to:
serialized_tf_example = tf.placeholder(tf.string, name='tf_example')
feature_configs = {'x': tf.FixedLenFeature(shape=[sequence_length],
dtype=tf.int64),}
tf_example = tf.parse_example(serialized_tf_example, feature_configs)
# use tf.identity() to assign name
data_in = tf.identity(tf_example['x'], name='x')
This works in the training phase, but at test time it fails with AbortionError(code=StatusCode.INVALID_ARGUMENT, details="Expects arg[0] to be int64 but string is provided").
I am confused about these lines:
feature_configs = {'x': tf.FixedLenFeature(shape=[sequence_length],
dtype=tf.int64),}
I changed them to:
feature_configs = {'x': tf.FixedLenFeature(shape=[sequence_length],
dtype=tf.string),}
but then training fails with the following error:
Traceback (most recent call last):
File "/serving/bazel-bin/tensorflow_serving/example/twitter-sentiment-cnn_saved_model.runfiles/tf_serving/tensorflow_serving/example/twitter-sentiment-cnn_saved_model.py", line 222, in <module>
embedded_chars = tf.nn.embedding_lookup(W, data_in)
File "/serving/bazel-bin/tensorflow_serving/example/twitter-sentiment-cnn_saved_model.runfiles/org_tensorflow/tensorflow/python/ops/embedding_ops.py", line 122, in embedding_lookup
return maybe_normalize(_do_gather(params[0], ids, name=name))
File "/serving/bazel-bin/tensorflow_serving/example/twitter-sentiment-cnn_saved_model.runfiles/org_tensorflow/tensorflow/python/ops/embedding_ops.py", line 42, in _do_gather
return array_ops.gather(params, ids, name=name)
File "/serving/bazel-bin/tensorflow_serving/example/twitter-sentiment-cnn_saved_model.runfiles/org_tensorflow/tensorflow/python/ops/gen_array_ops.py", line 1179, in gather
validate_indices=validate_indices, name=name)
File "/serving/bazel-bin/tensorflow_serving/example/twitter-sentiment-cnn_saved_model.runfiles/org_tensorflow/tensorflow/python/framework/op_def_library.py", line 589, in apply_op
param_name=input_name)
File "/serving/bazel-bin/tensorflow_serving/example/twitter-sentiment-cnn_saved_model.runfiles/org_tensorflow/tensorflow/python/framework/op_def_library.py", line 60, in _SatisfiesTypeConstraint
", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
TypeError: Value passed to parameter 'indices' has DataType string not in list of allowed values: int32, int64
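The traceback comes down to what an embedding lookup does: it selects rows of the embedding matrix W by integer position. The following is a minimal pure-Python sketch, not TensorFlow's implementation; the function name embedding_lookup_sketch and the toy matrix are hypothetical. It shows why a string id cannot be used as an index:

```python
def embedding_lookup_sketch(W, ids):
    """Return the rows of W selected by integer ids (a gather along axis 0)."""
    rows = []
    for i in ids:
        # Only an integer can select a row; a word string has no row position.
        if not isinstance(i, int):
            raise TypeError(f"ids must be integers, got {type(i).__name__}")
        rows.append(W[i])
    return rows

# Toy vocabulary of 4 entries with 2-dimensional embeddings (illustrative values).
W = [[0.0, 0.0], [0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]

print(embedding_lookup_sketch(W, [2, 0, 3]))  # [[0.3, 0.4], [0.0, 0.0], [0.5, 0.6]]
```

This is why switching the FixedLenFeature dtype to tf.string only moves the error: the values reach tf.nn.embedding_lookup as strings, and the underlying gather accepts only int32 or int64 indices.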
Your code is wrong here:
serialized_tf_example = tf.placeholder(tf.string, name='tf_example')
This means your input is a string, for example the words of a sentence. Therefore:
feature_configs = {'x': tf.FixedLenFeature(shape=[sequence_length],
dtype=tf.int64),}
tf_example = tf.parse_example(serialized_tf_example, feature_configs)
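In my view this does not make sense, because you never convert the vocabulary from string to int. You need to load the vocabulary of your training data to get the word indices!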
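A minimal sketch of that string-to-int conversion, assuming the training vocabulary is available as a list of words, with id 0 reserved for padding and id 1 for unknown words. These conventions and the names build_vocab, encode, PAD_ID and UNK_ID are hypothetical; reuse whatever mapping your training script actually used:

```python
from typing import Dict, List

PAD_ID = 0  # assumed id used to pad short sentences
UNK_ID = 1  # assumed id for words not seen during training

def build_vocab(words: List[str]) -> Dict[str, int]:
    """Map each training word to an integer id, starting after the reserved ids."""
    vocab: Dict[str, int] = {}
    for w in words:
        if w not in vocab:
            vocab[w] = len(vocab) + 2
    return vocab

def encode(sentence: str, vocab: Dict[str, int], sequence_length: int) -> List[int]:
    """Turn a sentence into exactly sequence_length integer ids (truncate or pad)."""
    ids = [vocab.get(tok, UNK_ID) for tok in sentence.lower().split()]
    ids = ids[:sequence_length]
    return ids + [PAD_ID] * (sequence_length - len(ids))

vocab = build_vocab(["i", "love", "this", "movie", "hate"])
print(encode("I love this film", vocab, sequence_length=6))  # [2, 3, 4, 1, 0, 0]
```

The resulting list of ints is what belongs in the int64 feature 'x' of each serialized tf.train.Example sent to the server, so the training script and the client must share the same vocabulary.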