在 flutter 中实现 TFlite 模型时出错

问题描述 投票:0回答:1

所以我一直在尝试在 flutter 中实现一些 tflite 模型,这些模型是我之前从 .h5 文件转换而来的。我在调试时特意使用了打印函数来了解问题所在。 出现如下消息后应用程序挂起: enter image description here 我在此之前使用的打印功能表明我的模型已成功加载 enter image description here

应用程序遇到运行时错误,指出错误的先决条件,通常指出输入大小参数(我在这里可能是错的)。

以下是我用来实现 Tflite 模型的函数。 基本上我需要运行几个模型(一次一个),在运行第一个模型后,我为其设置了一些“if”条件。

import 'package:flutter/material.dart';
import 'package:file_picker/file_picker.dart';
import 'package:permission_handler/permission_handler.dart';
import 'package:tflite_flutter/tflite_flutter.dart';
import 'package:image/image.dart' as img;
import 'dart:math';
import 'dart:io';



class ImageScreen extends StatefulWidget {
  const ImageScreen({Key? key}) : super(key: key);

  @override
  _ImageScreenState createState() => _ImageScreenState();
}
class _ImageScreenState extends State<ImageScreen> {

  var imgFile;
  String Image_Path= "";



  Future<String> pickImage() async {
      // Function body
      // Add a return statement at the end
      PermissionStatus status = await Permission.storage.request();

      FilePickerResult? result = await FilePicker.platform.pickFiles(
        type: FileType.custom,
        allowedExtensions: ['jpg', 'jpeg', 'png'],
      );

      if (result != null) {
        imgFile = result.files.single;
        print("Image file path: " + imgFile.path);
        Image_Path=imgFile.path;
        print("picked image");
        return Image_Path;
      } else {
        // User canceled the picker
        print("Error");
        return "Error";
      }
  }



//LOAD MODEL WORKS


  Interpreter? interpreter;
  Future<void> loadModel(String modelPath) async {
  try {
    interpreter = await Interpreter.fromAsset(modelPath);
    print('Loaded model successfully');

    // Get input and output shapes
    if (interpreter != null) {
      var inputShape = interpreter!.getInputTensor(0).shape;
      var outputShape = interpreter!.getOutputTensor(0).shape;

      print('Input shape: $inputShape');
      print('Output shape: $outputShape');
      
      //I/flutter (10135): Input shape: [1, 256, 256, 3]
      //I/flutter (10135): Output shape: [1, 4]

    }
  } catch (e) {
    print('Failed to load model: $e');
  }
}


Future<List> runModel(String imagePath) async {
  // Load the image
  var image = img.decodeImage(File(imagePath).readAsBytesSync());
  var resized = img.copyResize(image!, width: 256, height: 256);

  // Convert image to a list of floats
  var input = resized.getBytes().buffer.asFloat32List().reshape([49152]);

  // Create output tensor
  var output = List<double>.filled(4,0).reshape([1,4]);//try filled(4.0), reshape([1,4])

  // Check if interpreter is null
if (interpreter == null) {
  print('Interpreter is null');
  //return;
}

// Print input and output shapes
print('Input shape: ${input.shape}');
print('Output shape: ${output.shape}');

// Print input data
print('Input data: $input');

// Run the model
interpreter?.run(input, output);

  print("ok running model");
  return output;
}


Future<String> makePredictions(String imagePath) async {
  // Load and run the 4-class model
  await loadModel('assets/models/four_class_STFT_80valacc.tflite');
  var prediction_4class = await runModel(imagePath);

  String result = '';

  // Convert List<dynamic> to List<double>
  List<double> prediction_4class_double = prediction_4class.cast<double>();


  // Find the index of the maximum value in prediction_4class
  int maxIndex = prediction_4class.indexWhere((d) => d == prediction_4class_double.reduce(max));

  // Based on the result, load and run the appropriate binary model
  if (maxIndex == 0) {
    print("ok 1");
    result = 'Normal';
  } else if (maxIndex == 1) {
    // Asthma
    await loadModel('assets/models/NvsA.tflite');
    print("ok 1");
    var prediction_binary = await runModel(imagePath);
    result = 'Asthma with confidence ${prediction_binary[0]}';
  } else if (maxIndex == 2) {
    // Pneumonia
    print("ok 1");
    await loadModel('assets/models/NvsP.tflite');
    var prediction_binary = await runModel(imagePath);
    result = 'Pneumonia with confidence ${prediction_binary[0]}';
  } else if (maxIndex == 3) {
    // COPD
    print("ok 1");
    await loadModel('assets/models/NvsC_best.tflite');
    var prediction_binary = await runModel(imagePath);
    result = 'COPD with confidence ${prediction_binary[0]}';
  }

  return result;
}



  //Remove this if causing errors
 /*
  @override
  void initState() {
    super.initState();
    loadModel(Image_Path).then((value) {setState((){});});
  }*/



  @override
  Widget build(BuildContext context) {
    return Scaffold(
          backgroundColor: Colors.white,
      appBar:AppBar(
        elevation: 0,
        title:const Row(
                      mainAxisAlignment: MainAxisAlignment.center,
                      children: <Widget>[
                        Text(
                          "Ausculto",
                          style:
                          TextStyle(color: Color.fromARGB(221, 7, 173, 224), fontWeight: FontWeight.w600),
                        ),
                        Text(
                          "Wave",
                          style: TextStyle(color: Color.fromARGB(255, 248, 213, 16), fontWeight: FontWeight.w600),
                        ),
                        Text("       "),//balances out, maintains center, use better methods later
                      ],
                    ),centerTitle: true,
        backgroundColor: const Color.fromARGB(255, 255, 255, 255),
      ),

      body:Center(
        child:Container(
          width: 200,
          height: 250,
          child:ListView(
            children: [
              const Text(
                    " Results:",
                    style: TextStyle(fontSize: 24, fontWeight: FontWeight.bold,color: Colors.black),
              ),
              const SizedBox(height: 20,),
              
              //FOR TESTING ONLY: 
              
              Container(
                width:200,
                child:ElevatedButton(
                onPressed: () async{
                  //ability to add an image 
                  String? imagePath = await pickImage();
                  if (imagePath != null) {
                    // Run the model on the image
                    String result = await makePredictions(imagePath);

                  // Display the result in a dialog
                  showDialog(
                    context: context,
                    builder: (BuildContext context) {
                      return AlertDialog(
                        title: Text('Prediction'),
                        content: Text(result),
                        actions: <Widget>[
                          TextButton(
                            child: Text('Close'),
                            onPressed: () {
                              Navigator.of(context).pop();
                            },
                          ),
                        ],
                      );
                    },
                  );
                  }
                },
                //change color of elevated button here
                style: ElevatedButton.styleFrom(
                  backgroundColor: Colors.red,
                  foregroundColor: Colors.white,
                  shape: RoundedRectangleBorder(
                    borderRadius: BorderRadius.circular(18.0),
                  ),
                ),
                child: const Text('Add Test Image',style: TextStyle(fontSize: 13),),
              ),
              ),
              const SizedBox(height: 20,),
              
            ],
          ),
        ),
      ),
    );

    
  }
}

我想获得 tflite 模型的输出,模型告诉我一个人是否正常,然后进一步设置条件,如果模型预测这个人不正常并且说患有疾病,其他模型也执行相同的检查。 我花了 10 多个小时尝试不同的方法来实现 dart 文件,从不同的项目,所有这些都失败了并导致相同的结果或应用程序最终无法自行构建。 网络上的一些解决方案包含已弃用或与 dart 3.0 不兼容的软件包

android flutter tensorflow flutter-dependencies tensorflow-lite
1个回答
0
投票

所以我经过几个小时的调试才弄清楚了这一点。事实证明,labels.txt 中存在的元素数量改变了输出形状要求。

这是有效的代码:

return FutureBuilder<List<String>>(
  future: () async {
    try {
      print("Closed previous model");

      await Tflite.loadModel(
        model: 'assets/models/NvsA.tflite',
        labels: 'assets/models/labelsA.txt',
      );
      print("ModelA loaded successfully");
      var output = await Tflite.runModelOnImage(
        path: path_to_your_image, 
        numResults: 2,
        threshold:0.2,
        imageMean: 127.5,
        imageStd: 127.5,
      );
      print("Model A run successfully");
      print(output);

      String result;
      if(output?[0]['confidence'] > 0.5){
        print("Normal");
        result = "Normal";
      }
      else{
        print("Asthma");
        result = "Asthma";
      }
      //await Tflite.close(); // Close ModelA after inference
      print("ModelA closed successfully");
      return [result];
    } catch (error) {
      print(error); // Log errors for debugging
      return ["Error"];
    }
  }(),
  builder: (context, snapshot) {
    if (snapshot.connectionState == ConnectionState.waiting) {
      return const CircularProgressIndicator();
    } else if (snapshot.hasError) {
      return Text('Error: ${snapshot.error}');
    } else {
      return Text('Result: ${snapshot.data![0]}');
    }
  }

);

© www.soinside.com 2019 - 2024. All rights reserved.