自定义对象检测模型中的不同输出顺序导致 Android 应用程序出现错误

问题描述 投票:0回答:1

我按照以下文档中概述的步骤进行操作:https://tensorflow-object-detection-api-tutorial.readthedocs.io/en/latest/training.html 来训练我的自定义对象检测模型。作为参考,我使用的是 TF 2.10。但是,将其转换为 tflite 模型并在 Java 的 Android 应用程序中实现后,出现以下错误:

EXCEPTION: Failed on interpreter inference -> Cannot copy from a TensorFlowLite tensor (StatefulPartionedCall:1) with shape [1,10] to a Java object with shape [1,10,4]. 

在 TensorFlow 2.6 之前,元数据顺序为框、类、分数、检测数。现在,它似乎已经变成了分数、框、检测数、类别。

我尝试了两件事:1)降级到 TF2.5 这解决了这个问题,但会引发与其他库的不兼容问题,所以我不喜欢这种方法。 2) 根据here的建议之一,使用metadata writer明确声明输出序列;但是,这仍然会引发与上述相同的异常。加载模型(在元数据写入过程之后)并检查输出详细信息后,我看到以下内容:

[{'name': 'StatefulPartitionedCall:1', 'index': 249, 'shape': array([ 1, 10]), 'shape_signature': array([ 1, 10]), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'StatefulPartitionedCall:3', 'index': 247, 'shape': array([ 1, 10,  4]), 'shape_signature': array([ 1, 10,  4]), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'StatefulPartitionedCall:0', 'index': 250, 'shape': array([1]), 'shape_signature': array([1]), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'StatefulPartitionedCall:2', 'index': 248, 'shape': array([ 1, 10]), 'shape_signature': array([ 1, 10]), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]

显示的形状的顺序仍然与框、类别、分数、检测数量的顺序不匹配。在无需修改android应用程序代码的情况下,是否可以采取其他措施来避免tflite转换过程中输出形状的失真?

如果需要,这里是我用来将 tflite 友好的 saving_model 转换为 tflite 的简单转换脚本:

import tensorflow as tf
import argparse

parser = argparse.ArgumentParser(
    description="tfLite Converter")

parser.add_argument("--saved_model_path",
                    help="", type=str)
parser.add_argument("--tflite_model_path",
                    help="", type=str)

args = parser.parse_args()

converter = tf.lite.TFLiteConverter.from_saved_model(args.saved_model_path)
tflite_model = converter.convert()


with open(args.tflite_model_path, 'wb') as f:
  f.write(tflite_model)
python metadata tensorflow2.0 object-detection tensorflow-lite
1个回答
0
投票

您可以尝试在转换脚本中显式设置输出顺序:

import tensorflow as tf
import argparse

parser = argparse.ArgumentParser(description="tfLite Converter")
parser.add_argument("--saved_model_path", help="Path to the saved model", type=str)
parser.add_argument("--tflite_model_path", help="Path to save the tflite model", type=str)
args = parser.parse_args()

converter = tf.lite.TFLiteConverter.from_saved_model(args.saved_model_path)

# Ensure the converter uses the new experimental converter
converter.experimental_new_converter = True

# Set the output tensor order explicitly
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
tflite_model = converter.convert()

with open(args.tflite_model_path, 'wb') as f:
    f.write(tflite_model)


转换模型后,验证输出详细信息,检查顺序是否与 Android 应用程序期望的匹配


import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path=args.tflite_model_path)
interpreter.allocate_tensors()

# Get output details
output_details = interpreter.get_output_details()
for detail in output_details:
    print(detail)

如果输出顺序仍然不匹配,您可能需要调整 Android 代码以按正确的顺序读取输出。

© www.soinside.com 2019 - 2024. All rights reserved.