无法在 ObjectDetector 中使用导出到 TensorFlow Lite 的 YOLOv7

问题描述 投票:0回答:1

我有一个在我的自定义数据集上训练的 YOLOv7 模型。我成功地将模型导出到 TensorFlow lite,并能够使用它在 Python 中进行推理。但是,当我尝试在 Android 中使用相同的模型时,使用 带有 TensorFlow lite 的对象检测项目,它会抛出此错误:

java.lang.IllegalArgumentException: Error occurred when initializing ObjectDetector: The input tensor should have dimensions 1 x height x width x 3. Got 1 x 3 x 640 x 640.

是否可以更改 ObjectDetector 类的输入形状,或导出具有相应输入形状的 YOLOv7 或 YOLOv5 模型?

我尝试调整导出过程以更改 ONNX 模型的输入形状,该模型是从 PyTorch 导出到 TensorFlow Lite 的中间模型,但它引发了此错误:

ONNX export failure: Given groups=1, weight of size [32, 3, 3, 3], expected input[1, 640, 640, 3] to have 3 channels, but got 640 channels instead

更新:我使用 onnx2tf 导出具有 NHWC 输入形状的 .tflite 模型。现在Android项目抛出这个错误:

java.lang.RuntimeException: Error occurred when initializing ObjectDetector: Input tensor has type kTfLiteFloat32: it requires specifying NormalizationOptions metadata to preprocess input images.
我找不到使用此doc向元数据添加标准化选项的方法。有什么解决办法吗?

android pytorch object-detection tensorflow-lite yolov7
1个回答
0
投票
  1. ONNX 导出错误: 将尺寸重新排列为 NHWC 格式。然后,从 ONNX 导出到 TensorFlow 时,确保保留 NHWC 格式。

    nhwc_tensor = input_tensor.permute(0, 2, 3, 1)

  2. TensorFlow Lite 输入预处理: 错误输入张量的类型为 kTfLiteFloat32:它需要指定 NormalizationOptions 元数据来预处理输入图像。 (您需要指定输入图像在通过模型之前应如何标准化)。

将标准化元数据添加到您的 .tflite 模型中:

import tflite_support
from tflite_support.metadata_writers import image_classifier
from tflite_support.metadata_writers import writer_utils
from tflite_support import metadata

NORMALIZATION_MEAN = [127.5]
NORMALIZATION_STD = [127.5]

writer = image_classifier.MetadataWriter.create_for_inference(
    writer_utils.load_file("your_model.tflite"),
    input_norm_mean=NORMALIZATION_MEAN,
    input_norm_std=NORMALIZATION_STD)

writer_utils.save_file(writer.populate(), "your_model_tflite_here")
© www.soinside.com 2019 - 2024. All rights reserved.