Output image looks corrupted when using TensorFlow Lite on Android

Question · Votes: 0 · Answers: 1

I have this simple piece of code (from here) that applies a TensorFlow Lite model to an input image (size 480x270) and displays the resulting image after processing. The logic works when run with evsrnet_x4.tflite, and I can see the output image. The project assets also include another model, esrgan.tflite, which is supposed to take a 50x50 input image and produce a 200x200 output image. However, when I change the code to account for those dimensions, the output image looks corrupted for some reason, as shown below. I don't understand why.

What is going wrong here? What else should I change to make this work with esrgan?

package com.example.mobedsr;

import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;
import android.graphics.Bitmap;

import org.tensorflow.lite.DataType;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.gpu.CompatibilityList;
import org.tensorflow.lite.gpu.GpuDelegate;
import org.tensorflow.lite.support.common.ops.NormalizeOp;
import org.tensorflow.lite.support.image.ImageProcessor;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.support.image.ops.ResizeOp;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;


/** @brief  Super Resolution Model class
 *  @date   23/01/27
 */
public class SRModel {
    private boolean useGpu;

    public Interpreter interpreter;
    private Interpreter.Options options;
    private GpuDelegate gpuDelegate;
    private AssetManager assetManager;

    private final String MODEL_NAME = "evsrnet_x4.tflite"; //I want to change to esrgan.tflite

    SRModel(AssetManager assetManager, boolean useGpu) throws IOException {
        interpreter = null;
        gpuDelegate = null;

        this.assetManager = assetManager;
        this.useGpu = useGpu;

        // Initialize the TF Lite interpreter
        init();
    }

    private void init() throws IOException {
        options = new Interpreter.Options();

        // Set gpu delegate
        if (useGpu) {
            CompatibilityList compatList = new CompatibilityList();
            GpuDelegate.Options delegateOptions = compatList.getBestOptionsForThisDevice();
            gpuDelegate = new GpuDelegate(delegateOptions);
            options.addDelegate(gpuDelegate);
        }

        // Set TF Lite interpreter
        interpreter = new Interpreter(loadModelFile(), options);
    }

    /** @brief  Load .tflite model file to ByteBuffer
     *  @date   23/01/25
     */
    private ByteBuffer loadModelFile() throws IOException {
        AssetFileDescriptor assetFileDescriptor = assetManager.openFd(MODEL_NAME);
        FileInputStream fileInputStream = new FileInputStream(assetFileDescriptor.getFileDescriptor());

        FileChannel fileChannel = fileInputStream.getChannel();
        long startOffset = assetFileDescriptor.getStartOffset();
        long declaredLength = assetFileDescriptor.getDeclaredLength();

        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }

    public void run(Object a, Object b) {
        interpreter.run(a, b);
    }


    /** @brief  Prepare the input tensor from low resolution image
     *  @date   23/01/25
     */
    public TensorImage prepareInputTensor(Bitmap bitmap_lr) {
        TensorImage inputImage = TensorImage.fromBitmap(bitmap_lr);
        int height = bitmap_lr.getHeight();
        int width = bitmap_lr.getWidth();

        ImageProcessor imageProcessor = new ImageProcessor.Builder()
                .add(new ResizeOp(height, width, ResizeOp.ResizeMethod.NEAREST_NEIGHBOR))
                .add(new NormalizeOp(0.0f, 255.0f))
                .build();
        inputImage = imageProcessor.process(inputImage);

        return inputImage;
    }


    /** @brief  Prepare the output tensor for super resolution
     *  @date   23/01/25
     */
    public TensorImage prepareOutputTensor() {
        TensorImage srImage = new TensorImage(DataType.FLOAT32);
//        int[] srShape = new int[]{1080, 1920, 3};
        int[] srShape = new int[]{1920, 1080, 3};
        srImage.load(TensorBuffer.createFixedSize(srShape, DataType.FLOAT32));

        return srImage;
    }


    /** @brief  Convert tensor to bitmap image
     *  @date   23/01/25
     *  @param outputTensor super resolutioned image
     */
    public Bitmap tensorToImage(TensorImage outputTensor) {
        ByteBuffer srOut = outputTensor.getBuffer();
        srOut.rewind();

        int height = outputTensor.getHeight();
        int width = outputTensor.getWidth();

        Bitmap bmpImage = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888);
        int[] pixels = new int[width * height];

        for (int i = 0; i < width * height; i++) {
            int a = 0xFF;
            float r = srOut.getFloat() * 255.0f;
            float g = srOut.getFloat() * 255.0f;
            float b = srOut.getFloat() * 255.0f;

            pixels[i] = a << 24 | ((int) r << 16) | ((int) g << 8) | ((int) b);
        }

        bmpImage.setPixels(pixels, 0, width, 0, 0, width, height);

        return bmpImage;
    }


}

Update 1: I made all the changes proposed by the answer below. Things improved, but the output now looks pixelated and purple:
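A pixelated, color-tinted output like this often points to a value-range mismatch rather than a shape problem. Assuming esrgan.tflite behaves like the model in the TensorFlow Lite super-resolution example, its output floats are already in the [0, 255] range, so multiplying them by 255 again (as tensorToImage does) saturates or wraps every channel. A plain-Java sketch of the effect, with no Android or TFLite dependency (the helper name is ours, for illustration):

```java
// Sketch: shows how a value-range mismatch corrupts pixels.
// Assumption: esrgan.tflite (like the TF Lite super-resolution example
// model) already outputs floats in [0, 255], NOT [0, 1].
class RangeMismatchDemo {
    // Packs one RGB pixel into an ARGB int, clamping each channel to [0, 255].
    static int packArgb(float r, float g, float b) {
        int ri = Math.max(0, Math.min(255, (int) r));
        int gi = Math.max(0, Math.min(255, (int) g));
        int bi = Math.max(0, Math.min(255, (int) b));
        return 0xFF << 24 | (ri << 16) | (gi << 8) | bi;
    }

    public static void main(String[] args) {
        // A mid-gray pixel as the model might emit it: already in [0, 255].
        float r = 120f, g = 120f, b = 120f;

        // Wrong (under the assumption above): treating the output as [0, 1]
        // and multiplying by 255 blows every channel out to 255.
        int wrong = packArgb(r * 255f, g * 255f, b * 255f);

        // Right (under the assumption above): use the values directly.
        int right = packArgb(r, g, b);

        System.out.println(Integer.toHexString(wrong)); // ffffffff (blown out)
        System.out.println(Integer.toHexString(right)); // ff787878 (gray)
    }
}
```

If the model really does emit [0, 255] floats, dropping the `* 255.0f` in tensorToImage (while keeping the clamping) would be the corresponding fix; checking the actual output range with a debugger is the safest way to confirm.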

java android image-processing deep-learning tensorflow-lite
1 Answer
Votes: 0


Make a few adjustments:

  1. Adjust the MODEL_NAME variable to load the esrgan.tflite model:

private final String MODEL_NAME = "esrgan.tflite";

  2. Also, keep in mind that the esrgan.tflite model expects a 50x50 input image and produces a 200x200 output image, so you need to adjust how the input and output tensors are prepared accordingly:

  • Modify prepareInputTensor to make sure the image is resized to 50x50 before it is fed into the model:
public TensorImage prepareInputTensor(Bitmap bitmap_lr) {
    // Resizes the input image to 50x50.
    ImageProcessor imageProcessor = new ImageProcessor.Builder()
            .add(new ResizeOp(50, 50, ResizeOp.ResizeMethod.NEAREST_NEIGHBOR))
            .add(new NormalizeOp(0.0f, 255.0f))
            .build();
    TensorImage inputImage = TensorImage.fromBitmap(bitmap_lr);
    inputImage = imageProcessor.process(inputImage);

    return inputImage;
}
  • Modify prepareOutputTensor to adjust the output tensor shape to match the expected 200x200 output size:
public TensorImage prepareOutputTensor() {
    TensorImage srImage = new TensorImage(DataType.FLOAT32);

    // Adjusted to 200x200 and added batch size dimension
    int[] srShape = new int[]{1, 200, 200, 3};  
    srImage.load(TensorBuffer.createFixedSize(srShape, DataType.FLOAT32));

    return srImage;
}
  • Modify tensorToImage to make sure the pixel values are clamped to the [0, 255] range after conversion from float:
public Bitmap tensorToImage(TensorImage outputTensor) {
    ByteBuffer srOut = outputTensor.getBuffer();
    srOut.rewind();

    int height = outputTensor.getHeight();
    int width = outputTensor.getWidth();

    Bitmap bmpImage = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888);
    int[] pixels = new int[width * height];

    for (int i = 0; i < width * height; i++) {
        int a = 0xFF;
        int r = (int) (srOut.getFloat() * 255.0f);
        int g = (int) (srOut.getFloat() * 255.0f);
        int b = (int) (srOut.getFloat() * 255.0f);

        // Clamps the pixel values to [0, 255] range.
        r = Math.max(0, Math.min(255, r));
        g = Math.max(0, Math.min(255, g));
        b = Math.max(0, Math.min(255, b));

        pixels[i] = a << 24 | (r << 16) | (g << 8) | b;
    }

    bmpImage.setPixels(pixels, 0, width, 0, 0, width, height);

    return bmpImage;
}
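The loop above depends on the output buffer being laid out in NHWC order, i.e. row-major with the three channel floats interleaved per pixel, which is why the sequential getFloat() calls line up with pixels. A minimal plain-Java sketch of that read order, using a made-up two-pixel "image" (no Android types, so the Bitmap step is left out):

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Sketch: shows how sequential getFloat() reads map onto an NHWC float
// buffer, mirroring the loop in tensorToImage. The 2x1 "image" is
// invented for illustration; values are assumed already in [0, 255].
class NhwcReadDemo {
    static int[] bufferToPixels(ByteBuffer buf, int width, int height) {
        buf.rewind();
        int[] pixels = new int[width * height];
        for (int i = 0; i < pixels.length; i++) {
            // Channels are interleaved per pixel: R, G, B, R, G, B, ...
            int r = Math.max(0, Math.min(255, (int) buf.getFloat()));
            int g = Math.max(0, Math.min(255, (int) buf.getFloat()));
            int b = Math.max(0, Math.min(255, (int) buf.getFloat()));
            pixels[i] = 0xFF << 24 | (r << 16) | (g << 8) | b;
        }
        return pixels;
    }

    public static void main(String[] args) {
        // Two pixels in NHWC layout: pixel 0 = pure red, pixel 1 = pure blue.
        ByteBuffer buf = ByteBuffer.allocate(2 * 3 * 4).order(ByteOrder.nativeOrder());
        float[] nhwc = {255f, 0f, 0f, 0f, 0f, 255f};
        for (float v : nhwc) buf.putFloat(v);

        int[] px = bufferToPixels(buf, 2, 1);
        System.out.println(Integer.toHexString(px[0])); // ffff0000
        System.out.println(Integer.toHexString(px[1])); // ff0000ff
    }
}
```

If the real output ever looks channel-swapped (e.g. a blue/purple tint), swapping the r and b reads in a loop like this is a quick way to test whether the model emits BGR instead of RGB.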

Make sure you adjust the code to work seamlessly with the esrgan.tflite model, so that the pre-processing and post-processing steps meet its specific requirements. If needed, adjust the normalization values in NormalizeOp, and double-check that the input and output data types match the model's specification.

I hope this solves your problem.
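For reference, NormalizeOp(mean, stddev) computes (x - mean) / stddev per channel: NormalizeOp(0.0f, 255.0f) maps [0, 255] values into [0, 1], while NormalizeOp(0.0f, 1.0f) is the identity, which is what you would want if the model expects raw [0, 255] float input. A plain-Java sketch of that arithmetic (the helper below is ours, not part of the TFLite Support library):

```java
// Sketch of the arithmetic behind TFLite's NormalizeOp(mean, stddev),
// which applies (x - mean) / stddev per channel. The helper is
// illustrative only, not the TFLite Support API.
class NormalizeDemo {
    static float normalize(float x, float mean, float stddev) {
        return (x - mean) / stddev;
    }

    public static void main(String[] args) {
        // NormalizeOp(0.0f, 255.0f): maps byte values into [0, 1].
        System.out.println(normalize(255f, 0.0f, 255.0f)); // 1.0
        System.out.println(normalize(0f, 0.0f, 255.0f));   // 0.0

        // NormalizeOp(0.0f, 1.0f): identity, for models that expect
        // raw [0, 255] float input (possibly the case for esrgan.tflite).
        System.out.println(normalize(128f, 0.0f, 1.0f));   // 128.0
    }
}
```

So if esrgan.tflite turns out to expect unnormalized input, changing the NormalizeOp arguments in prepareInputTensor from (0.0f, 255.0f) to (0.0f, 1.0f) would be the corresponding adjustment.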
