所以我一直在尝试在 flutter 中实现一些 tflite 模型,这些模型是我之前从 .h5 文件转换而来的。我在调试时特意使用了打印函数来了解问题所在。 出现如下消息后应用程序挂起: 我在此之前使用的打印功能表明我的模型已成功加载
应用程序遇到运行时错误,指出错误的先决条件,通常指出输入大小参数(我在这里可能是错的)。
以下是我用来实现 Tflite 模型的函数。 基本上我需要运行几个模型(一次一个),在运行第一个模型后,我为其设置了一些“if”条件。
import 'package:flutter/material.dart';
import 'package:file_picker/file_picker.dart';
import 'package:permission_handler/permission_handler.dart';
import 'package:tflite_flutter/tflite_flutter.dart';
import 'package:image/image.dart' as img;
import 'dart:math';
import 'dart:io';
class ImageScreen extends StatefulWidget {
const ImageScreen({Key? key}) : super(key: key);
@override
_ImageScreenState createState() => _ImageScreenState();
}
class _ImageScreenState extends State<ImageScreen> {
var imgFile;
String Image_Path= "";
Future<String> pickImage() async {
// Function body
// Add a return statement at the end
PermissionStatus status = await Permission.storage.request();
FilePickerResult? result = await FilePicker.platform.pickFiles(
type: FileType.custom,
allowedExtensions: ['jpg', 'jpeg', 'png'],
);
if (result != null) {
imgFile = result.files.single;
print("Image file path: " + imgFile.path);
Image_Path=imgFile.path;
print("picked image");
return Image_Path;
} else {
// User canceled the picker
print("Error");
return "Error";
}
}
//LOAD MODEL WORKS
Interpreter? interpreter;
Future<void> loadModel(String modelPath) async {
try {
interpreter = await Interpreter.fromAsset(modelPath);
print('Loaded model successfully');
// Get input and output shapes
if (interpreter != null) {
var inputShape = interpreter!.getInputTensor(0).shape;
var outputShape = interpreter!.getOutputTensor(0).shape;
print('Input shape: $inputShape');
print('Output shape: $outputShape');
//I/flutter (10135): Input shape: [1, 256, 256, 3]
//I/flutter (10135): Output shape: [1, 4]
}
} catch (e) {
print('Failed to load model: $e');
}
}
Future<List> runModel(String imagePath) async {
// Load the image
var image = img.decodeImage(File(imagePath).readAsBytesSync());
var resized = img.copyResize(image!, width: 256, height: 256);
// Convert image to a list of floats
var input = resized.getBytes().buffer.asFloat32List().reshape([49152]);
// Create output tensor
var output = List<double>.filled(4,0).reshape([1,4]);//try filled(4.0), reshape([1,4])
// Check if interpreter is null
if (interpreter == null) {
print('Interpreter is null');
//return;
}
// Print input and output shapes
print('Input shape: ${input.shape}');
print('Output shape: ${output.shape}');
// Print input data
print('Input data: $input');
// Run the model
interpreter?.run(input, output);
print("ok running model");
return output;
}
Future<String> makePredictions(String imagePath) async {
// Load and run the 4-class model
await loadModel('assets/models/four_class_STFT_80valacc.tflite');
var prediction_4class = await runModel(imagePath);
String result = '';
// Convert List<dynamic> to List<double>
List<double> prediction_4class_double = prediction_4class.cast<double>();
// Find the index of the maximum value in prediction_4class
int maxIndex = prediction_4class.indexWhere((d) => d == prediction_4class_double.reduce(max));
// Based on the result, load and run the appropriate binary model
if (maxIndex == 0) {
print("ok 1");
result = 'Normal';
} else if (maxIndex == 1) {
// Asthma
await loadModel('assets/models/NvsA.tflite');
print("ok 1");
var prediction_binary = await runModel(imagePath);
result = 'Asthma with confidence ${prediction_binary[0]}';
} else if (maxIndex == 2) {
// Pneumonia
print("ok 1");
await loadModel('assets/models/NvsP.tflite');
var prediction_binary = await runModel(imagePath);
result = 'Pneumonia with confidence ${prediction_binary[0]}';
} else if (maxIndex == 3) {
// COPD
print("ok 1");
await loadModel('assets/models/NvsC_best.tflite');
var prediction_binary = await runModel(imagePath);
result = 'COPD with confidence ${prediction_binary[0]}';
}
return result;
}
//Remove this if causing errors
/*
@override
void initState() {
super.initState();
loadModel(Image_Path).then((value) {setState((){});});
}*/
@override
Widget build(BuildContext context) {
return Scaffold(
backgroundColor: Colors.white,
appBar:AppBar(
elevation: 0,
title:const Row(
mainAxisAlignment: MainAxisAlignment.center,
children: <Widget>[
Text(
"Ausculto",
style:
TextStyle(color: Color.fromARGB(221, 7, 173, 224), fontWeight: FontWeight.w600),
),
Text(
"Wave",
style: TextStyle(color: Color.fromARGB(255, 248, 213, 16), fontWeight: FontWeight.w600),
),
Text(" "),//balances out, maintains center, use better methods later
],
),centerTitle: true,
backgroundColor: const Color.fromARGB(255, 255, 255, 255),
),
body:Center(
child:Container(
width: 200,
height: 250,
child:ListView(
children: [
const Text(
" Results:",
style: TextStyle(fontSize: 24, fontWeight: FontWeight.bold,color: Colors.black),
),
const SizedBox(height: 20,),
//FOR TESTING ONLY:
Container(
width:200,
child:ElevatedButton(
onPressed: () async{
//ability to add an image
String? imagePath = await pickImage();
if (imagePath != null) {
// Run the model on the image
String result = await makePredictions(imagePath);
// Display the result in a dialog
showDialog(
context: context,
builder: (BuildContext context) {
return AlertDialog(
title: Text('Prediction'),
content: Text(result),
actions: <Widget>[
TextButton(
child: Text('Close'),
onPressed: () {
Navigator.of(context).pop();
},
),
],
);
},
);
}
},
//change color of elevated button here
style: ElevatedButton.styleFrom(
backgroundColor: Colors.red,
foregroundColor: Colors.white,
shape: RoundedRectangleBorder(
borderRadius: BorderRadius.circular(18.0),
),
),
child: const Text('Add Test Image',style: TextStyle(fontSize: 13),),
),
),
const SizedBox(height: 20,),
],
),
),
),
);
}
}
我想获得 tflite 模型的输出,模型告诉我一个人是否正常,然后进一步设置条件,如果模型预测这个人不正常并且说患有疾病,其他模型也执行相同的检查。 我花了 10 多个小时尝试不同的方法来实现 dart 文件,从不同的项目,所有这些都失败了并导致相同的结果或应用程序最终无法自行构建。 网络上的一些解决方案包含已弃用或与 dart 3.0 不兼容的软件包
所以我经过几个小时的调试才弄清楚了这一点。事实证明,labels.txt 中存在的元素数量改变了输出形状要求。
这是有效的代码:
return FutureBuilder<List<String>>(
future: () async {
try {
print("Closed previous model");
await Tflite.loadModel(
model: 'assets/models/NvsA.tflite',
labels: 'assets/models/labelsA.txt',
);
print("ModelA loaded successfully");
var output = await Tflite.runModelOnImage(
path: path_to_your_image,
numResults: 2,
threshold:0.2,
imageMean: 127.5,
imageStd: 127.5,
);
print("Model A run successfully");
print(output);
String result;
if(output?[0]['confidence'] > 0.5){
print("Normal");
result = "Normal";
}
else{
print("Asthma");
result = "Asthma";
}
//await Tflite.close(); // Close ModelA after inference
print("ModelA closed successfully");
return [result];
} catch (error) {
print(error); // Log errors for debugging
return ["Error"];
}
}(),
builder: (context, snapshot) {
if (snapshot.connectionState == ConnectionState.waiting) {
return const CircularProgressIndicator();
} else if (snapshot.hasError) {
return Text('Error: ${snapshot.error}');
} else {
return Text('Result: ${snapshot.data![0]}');
}
}
);