如何修改 TFLite 导出的 Yolov5s 模型的输出以与使用 kotlin 构建的 Android 应用程序配合使用?

问题描述 投票:0回答:1

我想将我的 TFLite 导出的 Yolov5s 模型加载到我的对象检测 Android 应用程序中。我遵循了本教程:https://www.youtube.com/watch?v=zs43IrWTzB0

但是,我的 TFLite Yolov5 模型输出形状为 [1, 25200, 9] 的数组。

同时,预期的输出签名是以下 4 个数组:Detection_boxes、Detection_Classes、Detection_Scores 和 num_Detections。根据https://www.tensorflow.org/lite/examples/object_detection/overview#output_signature

我应该如何修改我的代码以使其可以在此应用程序中加载?

这是我的 TF lite 模型提供的示例代码:

val model = BestFp16.newInstance(context)

// Creates inputs for reference.
val inputFeature0 = TensorBuffer.createFixedSize(intArrayOf(1, 640, 640, 3), DataType.FLOAT32)
inputFeature0.loadBuffer(byteBuffer)

// Runs model inference and gets result.
val outputs = model.process(inputFeature0)
val outputFeature0 = outputs.outputFeature0AsTensorBuffer

// Releases model resources if no longer used.
model.close()

这是我的 MainActivity.kt (注释部分是我尝试绘制边框的地方,但最终在屏幕的左上角绘制了静态边框,并且应用程序在几秒钟后崩溃):

package com.example.sightfulkotlin

import  android.annotation.SuppressLint
import android.content.Context
import android.content.pm.PackageManager
import android.graphics.*
import android.hardware.camera2.CameraCaptureSession
import android.hardware.camera2.CameraDevice
import android.hardware.camera2.CameraManager
import android.os.Bundle
import android.os.Handler
import android.os.HandlerThread
import android.view.Surface
import android.view.TextureView
import android.widget.ImageView
import androidx.appcompat.app.AppCompatActivity
import androidx.core.content.ContextCompat
import com.example.sightfulkotlin.ml.BestFp16
import org.tensorflow.lite.DataType
import org.tensorflow.lite.support.common.FileUtil
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.image.ops.ResizeOp
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer



class MainActivity : AppCompatActivity() {

    var colors = listOf(
        Color.BLUE, Color.GREEN, Color.RED, Color.CYAN, Color.GRAY, Color.BLACK, Color.DKGRAY, Color.MAGENTA, Color.YELLOW, Color.LTGRAY, Color.WHITE)
    val paint = Paint()
    private lateinit var labels:List<String>
    lateinit var bitmap: Bitmap
    lateinit var imageView: ImageView
    lateinit var cameraDevice: CameraDevice
    lateinit var handler: Handler
    private lateinit var cameraManager: CameraManager
    lateinit var textureView: TextureView
    lateinit var model: BestFp16

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)

        getPermission()

        labels = FileUtil.loadLabels(this, "labels.txt")
        model = BestFp16.newInstance(this)

        var imageProcessor = ImageProcessor.Builder().add(ResizeOp(640, 640, ResizeOp.ResizeMethod.BILINEAR)).build()

        val handlerThread = HandlerThread("videoThread")
        handlerThread.start()
        handler = Handler(handlerThread.looper)

        paint.color = Color.GREEN

        imageView = findViewById(R.id.imageView)
        textureView = findViewById(R.id.textureView)
        textureView.surfaceTextureListener = object: TextureView.SurfaceTextureListener
        {
            override fun onSurfaceTextureAvailable(p0: SurfaceTexture, p1: Int, p2: Int) {
                openCamera()
            }

            override fun onSurfaceTextureSizeChanged(p0: SurfaceTexture, p1: Int, p2: Int) {
            }

            override fun onSurfaceTextureDestroyed(p0: SurfaceTexture): Boolean {
                return false
            }

            override fun onSurfaceTextureUpdated(p0: SurfaceTexture) {
                bitmap = textureView.bitmap!!

                var tensorImage = TensorImage(DataType.FLOAT32)
                tensorImage.load(bitmap)
                tensorImage = imageProcessor.process(tensorImage)

                val inputFeature0 = TensorBuffer.createFixedSize(intArrayOf(1, 640, 640, 3), DataType.FLOAT32)
                inputFeature0.loadBuffer(tensorImage.buffer)

                val outputs = model.process(inputFeature0)
                val outputFeature0 = outputs.outputFeature0AsTensorBuffer

                val mutable = bitmap.copy(Bitmap.Config.ARGB_8888, true)
                val canvas = Canvas(mutable)

                val h = bitmap.height
                val w = bitmap.width
                paint.textSize = h/15f
                paint.strokeWidth = h/85f

                /*val detection = outputFeature0.floatArray.copyOfRange(0, 9)

                  val numDetections = outputFeature0.shape[1]

                   for (i in 0 until numDetections) {
                       i * 9
                       val xCenter = detection[0]
                       val yCenter = detection[1]
                       val width = detection[2]
                       val height = detection[3]

                       val left = (xCenter - width/2) * w
                       val top = (yCenter - height/2) * h
                       val right = (xCenter + width/2) * w
                       val bottom = (yCenter + height/2) * h

                       canvas.drawRect(left, top, right, bottom, paint)
                   }*/

                imageView.setImageBitmap(mutable)

            }
        }

        cameraManager =  getSystemService(Context.CAMERA_SERVICE) as CameraManager
    }

    override fun onDestroy() {
        super.onDestroy()
        model.close()
    }

    @SuppressLint("MissingPermission")
    fun openCamera()
    {
        cameraManager.openCamera(cameraManager.cameraIdList[0], object: CameraDevice.StateCallback(){
            @SuppressLint("MissingPermission")
            override fun onOpened(p0: CameraDevice) {
                cameraDevice = p0

                var surfaceTexture = textureView.surfaceTexture
                var surface  = Surface(surfaceTexture)
                var captureRequest = cameraDevice.createCaptureRequest(CameraDevice.TEMPLATE_PREVIEW)
                captureRequest.addTarget(surface)

                cameraDevice.createCaptureSession(listOf(surface), object: CameraCaptureSession.StateCallback(){
                    override fun onConfigured(p0: CameraCaptureSession) {
                        p0.setRepeatingRequest(captureRequest.build(), null, null)
                    }

                    override fun onConfigureFailed(p0: CameraCaptureSession) {
                    }
                }, handler)
            }

            override fun onDisconnected(p0: CameraDevice) {
            }

            @SuppressLint("MissingPermission")
            override fun onError(p0: CameraDevice, p1: Int) {
            }
        },handler)
    }

    fun getPermission()
    {
        if(ContextCompat.checkSelfPermission(this, android.Manifest.permission.CAMERA)!=PackageManager.PERMISSION_GRANTED)
        {
            requestPermissions(arrayOf(android.Manifest.permission.CAMERA), 101)
        }
    }

    override fun onRequestPermissionsResult(
        requestCode: Int,
        permissions: Array<out String>,
        grantResults: IntArray
    ) {
        super.onRequestPermissionsResult(requestCode, permissions, grantResults)
        if (grantResults[0] != PackageManager.PERMISSION_GRANTED)
        {
            getPermission()
        }
    }
}
kotlin object-detection tensorflow-lite object-detection-api yolov5
1个回答
0
投票

您的模型应该包含包含类信息的元数据。 当前的ultralytics yolov5 Github存储库不支持这种对象检测模型的转换,该模型可以添加元数据并稍后在android上使用。

原因是 YOLOv5 导出的模型通常将输出连接成单个输出。 TFLite 模型不使用 NMS 导出,只有 TF.js 和管道 CoreML 模型包含 NMS。 资料取自这篇文章,这个问题是有解决办法的。 你可以尝试这个选项,但它对我不起作用。 也可能的解决方案:

  1. 您也可以尝试不添加元数据,而是使用输出张量的解压缩。代码是这里
  2. 用Java重写代码,有机会加载没有元数据的模型;
  3. 训练另一个支持元数据记录的模型
© www.soinside.com 2019 - 2024. All rights reserved.