CNN for Sentiment Analysis使用Android的TFLearn模型来分类用户输入

问题描述 投票:-3回答:1

我有一个CNN模型用于文本分类,它使用预先训练的手套嵌入。我已经冻结了针对推理优化的图形并在android studio上使用它。问题是当我尝试将权重传递给模型进行推理时。我有一个JSON文件,其中包含单词和嵌入之间的键值对,我用它来创建用户输入的文本的嵌入输入。我已经可以从JSON文件中获取嵌入但是当我尝试将它提供给图形进行推理,它给出了以下错误:

java.lang.IllegalArgumentException: indices[0,3891] = -2 is not in [0, 
7459)
[[Node: EmbeddingLayer/embedding_lookup = Gather[Tindices=DT_INT32, 
Tparams=DT_FLOAT, _class=["loc:@EmbeddingLayer/W"], 
validate_indices=false, 
_device="/job:localhost/replica:0/task:0/device:CPU:0"] 

(嵌入图层/带/读取,嵌入图层/强制转换)]]

Android代码在我的GitHub https://github.com/sushiboo/testNN1

给我问题的主要代码是Classify方法:

private void classify(float[] input){
TFInference = new TensorFlowInferenceInterface(getAssets(), MODEL_FILE);

TFInference.feed(INPUT_NODE, input, 1, input.length);
TFInference.run(OUTPUT_NODES);
float[] resu = new float[2];
TFInference.fetch(OUTPUT_NODE, resu);
tvResult.setText("Programmer: " + Float.toString(resu[0]) + "\n Construction" +  Float.toString(resu[1]));
Log.e("Result: ", Float.toString(resu[0]));
}

问题在于

TFInference.run(OUTPUT_NODES);

在错误消息上,数字'7459'表示嵌入层的输入维度。

我真的很困惑这里发生了什么,但我知道指数[0,3891] = -2在这方面发挥了作用。

java android tensorflow tflearn
1个回答
-2
投票

问题在于模型人员。我已修复此问题,现在我遇到了另一个错误。

最新问题
© www.soinside.com 2019 - 2024. All rights reserved.