I have a problem that is a bit complex for me: I am using TensorFlow in Flutter to detect objects in real time (a live detector).
There is a problem with my model's output (I use a .tflite model in my project).
I trained on my images, and when I export the model and add it to my project, I get this error:
E/flutter ( 8569): Invalid argument(s): Output object shape mismatch, interpreter returned output of shape: [1, 2] while shape of output provided as argument in run is: [1, 10, 4]
The detection code is:
import 'dart:async';
import 'dart:io';
import 'dart:isolate';
import 'package:camera/camera.dart';
import 'package:flutter/foundation.dart';
import 'package:flutter/services.dart';
import 'package:image/image.dart' as image_lib;
import 'package:live_object_detection_ssd_mobilenet/models/recognition.dart';
import 'package:live_object_detection_ssd_mobilenet/utils/image_utils.dart';
import 'package:tflite_flutter/tflite_flutter.dart';
enum _Codes {
init,
busy,
ready,
detect,
result,
}
class _Command {
const _Command(this.code, {this.args});
final _Codes code;
final List<Object>? args;
}
/// Runs in the main isolate: spawns a background isolate for inference,
/// loads the TFLite model and labels, and streams recognition results.
class Detector {
static const String _modelPath = 'assets/models/ssd_mobilenet.tflite';
static const String _labelPath = 'assets/models/labelmap.txt';
Detector._(this._isolate, this._interpreter, this._labels);
final Isolate _isolate;
late final Interpreter _interpreter;
late final List<String> _labels;
late final SendPort _sendPort;
bool _isReady = false;
final StreamController<Map<String, dynamic>> resultsStream =
StreamController<Map<String, dynamic>>();
static Future<Detector> start() async {
final ReceivePort receivePort = ReceivePort();
final Isolate isolate =
await Isolate.spawn(_DetectorServer._run, receivePort.sendPort);
final Detector result = Detector._(
isolate,
await _loadModel(),
await _loadLabels(),
);
receivePort.listen((message) {
result._handleCommand(message as _Command);
});
return result;
}
static Future<Interpreter> _loadModel() async {
final interpreterOptions = InterpreterOptions();
if (Platform.isAndroid) {
interpreterOptions.addDelegate(XNNPackDelegate());
}
return Interpreter.fromAsset(
_modelPath,
options: interpreterOptions..threads = 4,
);
}
static Future<List<String>> _loadLabels() async {
return (await rootBundle.loadString(_labelPath)).split('\n');
}
void processFrame(CameraImage cameraImage) {
if (_isReady) {
_sendPort.send(_Command(_Codes.detect, args: [cameraImage]));
}
}
void _handleCommand(_Command command) {
switch (command.code) {
case _Codes.init:
_sendPort = command.args?[0] as SendPort;
RootIsolateToken rootIsolateToken = RootIsolateToken.instance!;
_sendPort.send(_Command(_Codes.init, args: [
rootIsolateToken,
_interpreter.address,
_labels,
]));
case _Codes.ready:
_isReady = true;
case _Codes.busy:
_isReady = false;
case _Codes.result:
_isReady = true;
resultsStream.add(command.args?[0] as Map<String, dynamic>);
default:
debugPrint('Detector unrecognized command: ${command.code}');
}
}
void stop() {
_isolate.kill();
}
}
/// Runs in the background isolate: converts camera frames and performs the
/// actual TFLite inference.
class _DetectorServer {
static const int mlModelInputSize = 300;
static const double confidence = 0.5;
Interpreter? _interpreter;
List<String>? _labels;
_DetectorServer(this._sendPort);
final SendPort _sendPort;
static void _run(SendPort sendPort) {
ReceivePort receivePort = ReceivePort();
final _DetectorServer server = _DetectorServer(sendPort);
receivePort.listen((message) async {
final _Command command = message as _Command;
await server._handleCommand(command);
});
sendPort.send(_Command(_Codes.init, args: [receivePort.sendPort]));
}
Future<void> _handleCommand(_Command command) async {
switch (command.code) {
case _Codes.init:
RootIsolateToken rootIsolateToken =
command.args?[0] as RootIsolateToken;
BackgroundIsolateBinaryMessenger.ensureInitialized(rootIsolateToken);
_interpreter = Interpreter.fromAddress(command.args?[1] as int);
_labels = command.args?[2] as List<String>;
_sendPort.send(const _Command(_Codes.ready));
case _Codes.detect:
_sendPort.send(const _Command(_Codes.busy));
_convertCameraImage(command.args?[0] as CameraImage);
default:
debugPrint('_DetectorService unrecognized command ${command.code}');
}
}
void _convertCameraImage(CameraImage cameraImage) {
var preConversionTime = DateTime.now().millisecondsSinceEpoch;
convertCameraImageToImage(cameraImage).then((image) {
if (image != null) {
if (Platform.isAndroid) {
image = image_lib.copyRotate(image, angle: 90);
}
final results = analyseImage(image, preConversionTime);
_sendPort.send(_Command(_Codes.result, args: [results]));
}
});
}
/// Resizes the frame, runs inference, and converts the raw output tensors
/// into Recognition objects plus timing stats.
Map<String, dynamic> analyseImage(
image_lib.Image? image, int preConversionTime) {
var conversionElapsedTime =
DateTime.now().millisecondsSinceEpoch - preConversionTime;
var preProcessStart = DateTime.now().millisecondsSinceEpoch;
final imageInput = image_lib.copyResize(
image!,
width: mlModelInputSize,
height: mlModelInputSize,
);
final imageMatrix = List.generate(
imageInput.height,
(y) => List.generate(
imageInput.width,
(x) {
final pixel = imageInput.getPixel(x, y);
return [pixel.r, pixel.g, pixel.b];
},
),
);
var preProcessElapsedTime =
DateTime.now().millisecondsSinceEpoch - preProcessStart;
var inferenceTimeStart = DateTime.now().millisecondsSinceEpoch;
final output = _runInference(imageMatrix);
final locationsRaw = output.first.first as List<List<double>>;
final List<Rect> locations = locationsRaw
.map((list) => list.map((value) => (value * mlModelInputSize)).toList())
.map((rect) => Rect.fromLTRB(rect[1], rect[0], rect[3], rect[2]))
.toList();
final classesRaw = output.elementAt(1).first as List<double>;
final classes = classesRaw.map((value) => value.toInt()).toList();
final scores = output.elementAt(2).first as List<double>;
final numberOfDetectionsRaw = output.last.first as double;
final numberOfDetections = numberOfDetectionsRaw.toInt();
final List<String> classification = [];
for (var i = 0; i < numberOfDetections; i++) {
classification.add(_labels![classes[i]]);
}
List<Recognition> recognitions = [];
for (int i = 0; i < numberOfDetections; i++) {
var score = scores[i];
var label = classification[i];
if (score > confidence) {
recognitions.add(
Recognition(i, label, score, locations[i]),
);
}
}
var inferenceElapsedTime =
DateTime.now().millisecondsSinceEpoch - inferenceTimeStart;
var totalElapsedTime =
DateTime.now().millisecondsSinceEpoch - preConversionTime;
return {
"recognitions": recognitions,
"stats": <String, String>{
'Conversion time:': conversionElapsedTime.toString(),
'Pre-processing time:': preProcessElapsedTime.toString(),
'Inference time:': inferenceElapsedTime.toString(),
'Total prediction time:': totalElapsedTime.toString(),
'Frame': '${image.width} X ${image.height}',
},
};
}
List<List<Object>> _runInference(
List<List<List<num>>> imageMatrix,
) {
final input = [imageMatrix];
// Output buffers sized for the four SSD MobileNet post-processing tensors:
// 0: locations [1, 10, 4], 1: classes [1, 10], 2: scores [1, 10], 3: count [1].
final output = {
0: [List<List<num>>.filled(10, List<num>.filled(4, 0))],
1: [List<num>.filled(10, 0)],
2: [List<num>.filled(10, 0)],
3: [0.0],
};
_interpreter!.runForMultipleInputs([input], output);
return output.values.toList();
}
}
If you trained it with Teachable Machine, then it is most likely not an object detection model but an image classification model. That would explain the output shape [1, 2]. Change the output buffer to match [1, 2] and print the result; that will give you an idea of what the model actually returns:
var output = {0: [List<num>.filled(2, 0)]};
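As a minimal sketch only (the method name _runClassification, the two-class assumption, and the comments are mine, not taken from your project): assuming the exported model is a float image classifier with a single output tensor of shape [1, 2], the inference call could look roughly like this. Check the real shapes first with _interpreter!.getInputTensor(0).shape and _interpreter!.getOutputTensor(0).shape, since Teachable Machine image models typically expect a different input size (often 224 x 224) than the 300 x 300 used in this sample.

// Hypothetical replacement for _runInference, assuming a single-output
// image classification model with output shape [1, 2].
List<double> _runClassification(List<List<List<num>>> imageMatrix) {
  final input = [imageMatrix]; // shape [1, height, width, 3]
  final output = [List<double>.filled(2, 0.0)]; // shape [1, 2]
  // One input and one output tensor, so run() is enough here.
  // For a quantized (uint8) model the buffer would need to hold ints instead.
  _interpreter!.run(input, output);
  debugPrint('raw scores: ${output.first}');
  return output.first;
}

The index of the highest score can then be mapped to its entry in labelmap.txt. Keep in mind that a classifier only tells you what is in the frame, not where it is, so the bounding-box handling in this sample (locations, Rect, Recognition) will not work with such a model; for boxes you would need to train an actual object detection model (e.g. SSD MobileNet) instead.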