无法从形状为 [1, 30, 8400] 的 TensorFlowLite 张量(Identity)复制到形状为 [1, 26] 的 Java 对象

问题描述 投票:0回答:1

我使用包含 26 个类的自定义数据集训练了一个模型 yolov8,但是当我将模型转换为 tflite 时,我注意到它给出了输出 [1,30,8400],这就是在使用我的模型与 flutter 时导致我出错的原因。

错误

E/AndroidRuntime(18479): Caused by: java.lang.IllegalArgumentException: Cannot copy from a TensorFlowLite tensor (Identity) with shape [1, 30, 8400] to a Java object with shape [1, 26].

如何修改模型的输出形状?

这就是我的模型的训练方式:

 from ultralytics import YOLO

 model = YOLO('yolov8s.pt')

 results = model.train(data='/kaggle/input/my- 
 dataset/my_dataset/data.yaml', epochs=100, imgsz=640)

这是文件 data.yaml 的内容

train: /kaggle/input/dataset-asl/ASL/train/images
val: /kaggle/input/dataset-asl/ASL/valid/images

nc: 26
names: ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 
'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 
'Y', 'Z']

这就是将我的模型转换为 tflite 格式的方法:

from ultralytics import YOLO

# Load a model
model = YOLO('best.pt')

# Export the model
model.export(format='tflite')

这是转换后的输出:

Ultralytics YOLOv8.2.4  Python-3.9.0 torch-2.2.1+cpu CPU (Intel 
Core(TM) i5-7300U 2.60GHz)
Model summary (fused): 168 layers, 3010718 parameters, 0 
gradients, 8.1 GFLOPs

PyTorch: starting from 'best.pt' with input shape (1, 3, 640, 640) 
BCHW and output shape(s) (1, 30, 8400) (6.0 MB)

TensorFlow SavedModel: starting export with tensorflow 2.15.0...
WARNING  tensorflow<=2.13.1 is required, but tensorflow==2.15.0 is 
currently installed 
https://github.com/ultralytics/ultralytics/issues/5161

ONNX: starting export with onnx 1.15.0 opset 17...
ONNX: simplifying with onnxsim 0.4.36...
ONNX: export success  1.8s, saved as 'best.onnx' (11.7 MB)
TensorFlow SavedModel: starting TFLite export with onnx2tf 
1.17.5...
TensorFlow SavedModel: export success  14.3s, saved as 
'best_saved_model' (29.5 MB)

TensorFlow Lite: starting export with tensorflow 2.15.0...
TensorFlow Lite: export success  0.0s, saved as 
'best_saved_model\best_float32.tflite' (11.7 MB)

Export complete (16.8s)
Results saved to C:\Users\Bachir\Desktop\api\tflite
Predict:         yolo predict task=detect 
model=best_saved_model\best_float32.tflite imgsz=640  
Validate:        yolo val task=detect 
model=best_saved_model\best_float32.tflite imgsz=640 
data=/kaggle/input/my-dataset/my_dataset/data.yaml  
Visualize:       https://netron.app
'best_saved_model\\best_float32.tflite'
emphasized text
flutter tensorflow-lite yolov8 tflite
1个回答
0
投票

您的模型输出形状是正确的。输出形状中的 30 表示 [4 个边界框位置 + 26 个类别分数]。 8400表示有8400个可能的边界框。 Tensorflow lite 不具有非最大值抑制功能。因此,您需要组合其中一些框才能获得最佳结果。

© www.soinside.com 2019 - 2024. All rights reserved.