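I trained a YOLOv8 model on a custom dataset with 26 classes. After converting the model to TFLite, I noticed that its output has shape [1, 30, 8400], which causes an error when I use the model in Flutter.
Error: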
E/AndroidRuntime(18479): Caused by: java.lang.IllegalArgumentException: Cannot copy from a TensorFlowLite tensor (Identity) with shape [1, 30, 8400] to a Java object with shape [1, 26].
How can I change the output shape of the model?
This is how I trained the model:
from ultralytics import YOLO
model = YOLO('yolov8s.pt')
results = model.train(data='/kaggle/input/my-dataset/my_dataset/data.yaml', epochs=100, imgsz=640)
This is the content of data.yaml:
train: /kaggle/input/dataset-asl/ASL/train/images
val: /kaggle/input/dataset-asl/ASL/valid/images
nc: 26
names: ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K',
'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X',
'Y', 'Z']
This is how I converted the model to TFLite format:
from ultralytics import YOLO
# Load a model
model = YOLO('best.pt')
# Export the model
model.export(format='tflite')
This is the output of the conversion:
Ultralytics YOLOv8.2.4 Python-3.9.0 torch-2.2.1+cpu CPU (Intel Core(TM) i5-7300U 2.60GHz)
Model summary (fused): 168 layers, 3010718 parameters, 0 gradients, 8.1 GFLOPs
PyTorch: starting from 'best.pt' with input shape (1, 3, 640, 640) BCHW and output shape(s) (1, 30, 8400) (6.0 MB)
TensorFlow SavedModel: starting export with tensorflow 2.15.0...
WARNING tensorflow<=2.13.1 is required, but tensorflow==2.15.0 is currently installed https://github.com/ultralytics/ultralytics/issues/5161
ONNX: starting export with onnx 1.15.0 opset 17...
ONNX: simplifying with onnxsim 0.4.36...
ONNX: export success 1.8s, saved as 'best.onnx' (11.7 MB)
TensorFlow SavedModel: starting TFLite export with onnx2tf 1.17.5...
TensorFlow SavedModel: export success 14.3s, saved as 'best_saved_model' (29.5 MB)
TensorFlow Lite: starting export with tensorflow 2.15.0...
TensorFlow Lite: export success 0.0s, saved as 'best_saved_model\best_float32.tflite' (11.7 MB)
Export complete (16.8s)
Results saved to C:\Users\Bachir\Desktop\api\tflite
Predict: yolo predict task=detect model=best_saved_model\best_float32.tflite imgsz=640
Validate: yolo val task=detect model=best_saved_model\best_float32.tflite imgsz=640 data=/kaggle/input/my-dataset/my_dataset/data.yaml
Visualize: https://netron.app
'best_saved_model\\best_float32.tflite'
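Your model's output shape is correct. The 30 in the output shape stands for [4 bounding box coordinates + 26 class scores], and 8400 is the number of candidate bounding boxes. The error comes from the Flutter side, which allocates an output buffer of shape [1, 26]; the buffer has to match [1, 30, 8400] instead. The exported TFLite model does not include non-maximum suppression, so you need to filter the candidates by score and merge the overlapping boxes yourself to get the final detections.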
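The post-processing described above could be sketched roughly as follows. This is only a minimal illustration, not the Ultralytics implementation. It assumes that the first four rows of the output hold each box as center x, center y, width, height and that the remaining rows hold the per-class scores. The function names `iou` and `decode` and the threshold values 0.25 and 0.45 are hypothetical choices made for this example. It is written in plain Python without NumPy so that the same loops can be ported to Dart for Flutter.

```python
def iou(a, b):
    """Intersection over union of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def decode(pred, conf_thres=0.25, iou_thres=0.45):
    """Turn one image's raw output into a list of (score, class_id, box).

    pred is indexable as pred[row][col] with 4 + num_classes rows
    (30 for this model) and one column per candidate box (8400 here).
    Assumption: rows 0..3 are center x, center y, width, height.
    """
    num_classes = len(pred) - 4
    num_boxes = len(pred[0])

    # Step 1: keep only candidates whose best class score passes the threshold.
    candidates = []
    for i in range(num_boxes):
        scores = [pred[4 + c][i] for c in range(num_classes)]
        best = max(range(num_classes), key=lambda c: scores[c])
        if scores[best] < conf_thres:
            continue
        cx, cy, w, h = (pred[r][i] for r in range(4))
        box = (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)
        candidates.append((scores[best], best, box))

    # Step 2: greedy per-class non-maximum suppression, highest score first.
    candidates.sort(key=lambda d: d[0], reverse=True)
    kept = []
    for det in candidates:
        if all(det[1] != k[1] or iou(det[2], k[2]) <= iou_thres for k in kept):
            kept.append(det)
    return kept


# Toy tensor with 2 classes and 3 candidates, standing in for output[0].
toy = [
    [0.50, 0.52, 0.20],  # center x
    [0.50, 0.50, 0.20],  # center y
    [0.20, 0.20, 0.10],  # width
    [0.20, 0.20, 0.10],  # height
    [0.90, 0.80, 0.10],  # score for class 0
    [0.05, 0.10, 0.10],  # score for class 1
]
detections = decode(toy)
```

With the real model you would pass the first (and only) batch element of the [1, 30, 8400] output to `decode`. Whether the coordinates are in pixels or normalized to the 0..1 range depends on the export, so check a known image before scaling the boxes back to the original picture.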