出口tflite模型预测不正确输出

问题描述 投票:0回答:1

我有一个训练有素的定制Tensorflow模型,并希望在通过火力地堡MLKit我的iOS应用程序来使用它。该模型与一个隐含层,其是这样的简单的4英寸-4出神经网络。

num_data_input = 4
num_units = 12
num_display = 4

xd = tf.placeholder(tf.float32, [None, num_data_input])

w1 = tf.Variable(tf.truncated_normal([num_data_input, num_units],dtype=tf.float32))
b1 = tf.Variable(tf.zeros([num_units],dtype=tf.float32))
hidden1 = tf.nn.sigmoid(tf.matmul(xd, w1) + b1)

w0 = tf.Variable(tf.zeros([num_units, num_display],dtype=tf.float32))
b0 = tf.Variable(tf.zeros([num_display],dtype=tf.float32))
p = tf.nn.softmax(tf.matmul(hidden1, w0) + b0)

ref = tf.placeholder(tf.float32, [None,num_display])
loss = -tf.reduce_sum(ref * tf.log(p))

train_step = tf.train.AdamOptimizer().minimize(loss)
correct_prediction = tf.equal(tf.argmax(p, 1), tf.argmax(ref, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

该模型的目的是通过使用单独放置在一间预测你最接近中在房间设置4个锚定点其中一个4的源信号。因此,输出应显示为每个定位点的概率。 (源信号位置和锚点之间的关系不是直接的,这就是为什么我在考虑使用机器学习,顺便说一句。)

约10000次迭代训练后,loss已下降到约0.3。 (是不是不够好,但在这里不是一个问题)。

训练后,我得到了这样的价值观。

print("input", dataarray[0][0])
print("output", sess.run(p, {xd: np.array(dataarray[0][0], dtype=np.float32).reshape(1,4)}))
# Results in:
# input [-87.43277416700528, -81.06589379945419, -71.74611110703701, -71.10851819430701]
# output [[1.5792685e-14 1.7755997e-01 7.4530774e-01 7.7132232e-02]]

print("input", dataarray[10][0])
print("output", sess.run(p, {xd: np.array(dataarray[10][0], dtype=np.float32).reshape(1,4)}))
# Results in:
# input [-86.87348060585144, -79.92684778533435, -71.24158331694396, -71.81342917898614]
# output [[1.30361505e-14 1.73598051e-01 7.51829445e-01 7.45724738e-02]]

除了它的正确与否,你至少可以看到它是针对不同的输入报告不同的值。

有了这个结果,我在Python代码做了tflite模型saved_model

tf.saved_model.simple_save(sess, "model", inputs={"input": xd}, outputs={"output": p})

和从命令行tflite_convert

tflite_convert --output_file=tmp/model.tflite --saved_model_dir=model

然后将它导入在通过云我斯威夫特项目:

    let conditions = ModelDownloadConditions(isWiFiRequired: true, canDownloadInBackground: true)
    let cloudModelSource = CloudModelSource(
        modelName: "my-model",
        enableModelUpdates: false,
        initialConditions: conditions,
        updateConditions: conditions
    )
    let registrationSuccessful = ModelManager.modelManager().register(cloudModelSource)

    let options = ModelOptions(
        cloudModelName: "my-model",
        localModelName: nil)
    interpreter = ModelInterpreter.modelInterpreter(options: options)

    ioOptions = ModelInputOutputOptions()
    do {
        try ioOptions.setInputFormat(index: 0, type: .float32, dimensions: [1, 4])
        try ioOptions.setOutputFormat(index: 0, type: .float32, dimensions: [1, 4])
    } catch let error as NSError {
        print("Failed to set input or output format with error: \(error.localizedDescription)")
    }

和执行预测:

    // 1st value
    let inputValues: [Double] = [-78.07635984967995, -76.68000728404165, -73.98016027165527, -74.77428875130332] 
    // 2nd value
    // let inputValues: [Double] = [-86.87348060585144, -79.92684778533435, -71.24158331694396, -71.81342917898614]

    let inputs = ModelInputs()
    let converted: [Float32] = inputValues.map { Float32($0) }
    do {
        try inputs.addInput([converted])
    } catch let error {
        print("Failed to add input: \(error)")
    }

    interpreter.run(inputs: inputs, options: ioOptions) { (outputs, error) in
        guard error == nil, let outputs = outputs else { return }
        do {
            if let ov = try outputs.output(index: 0) as? [[NSNumber]] {
                print("output = \(ov)") 
                // output = [[0.089901, 0.2951571, 0.2564065, 0.3585353]]
                // ^ Different from above result in python!
                //   And gives me the same value even when the input value is switched to "2nd value" above
            }
        } catch let error {
            print("output retrieval error: \(error)")
        }
    }
}

首先,从银行代码的输出值是从所述一个在Python不同。而最重要的是,我试图与他们都给予我同样的价值的一些不同的值。我也试图与现实世界的信号值,但它给了我同样的值在几乎所有情况。只有当我给它一些极端值远出我的预期范围的它显示了不同的值。

你看在我的Python代码或者SWIFT代码的任何问题吗?在我tflite模型转换或遗漏什么吗?

任何信息都是有帮助的。

谢谢

ios firebase-mlkit
1个回答
0
投票

部分解决。或者准确说,这是一个不同的问题。

通过提供本地束模型,它的工作如预期。

我发现,即使我更换了模型的云使用前代机型的应用不断。

我没有上面提到的,但我得到上述结果之前尝试了不同的tflite模型。我没太注意关于它的第一个结果,因为我知道该模型是相当不完整的。如果我更注重它我会较早发现,结果还是没改变的。

反正我现在必须弄清楚为什么预期从云检索将无法正常工作,但我希望这有助于信息别人。

© www.soinside.com 2019 - 2024. All rights reserved.