I have this simple piece of code (from here) that applies a TensorFlow Lite model to an input image (sized 480x270) and displays the resulting image after processing. The logic works when run with evsrnet_x4.tflite, and I can see the output image. The project assets also include another model, esrgan.tflite, which is supposed to take a 50x50 input image and produce a 200x200 output image. However, when I change the code to account for those dimensions, the output image looks corrupted for some reason, as shown below. I don't understand why.
What is going wrong here? What else should I change to make this work with esrgan?
package com.example.mobedsr;
import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.gpu.CompatibilityList;
import org.tensorflow.lite.gpu.GpuDelegate;
import org.tensorflow.lite.support.common.ops.NormalizeOp;
import org.tensorflow.lite.support.image.ImageProcessor;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.support.image.ops.ResizeOp;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
/** @brief Super Resolution Model class
 * @date 23/01/27
 */
public class SRModel {
    private boolean useGpu;
    public Interpreter interpreter;
    private Interpreter.Options options;
    private GpuDelegate gpuDelegate;
    private AssetManager assetManager;
    private final String MODEL_NAME = "evsrnet_x4.tflite"; // I want to change this to esrgan.tflite

    SRModel(AssetManager assetManager, boolean useGpu) throws IOException {
        interpreter = null;
        gpuDelegate = null;
        this.assetManager = assetManager;
        this.useGpu = useGpu;
        // Initialize the TF Lite interpreter
        init();
    }

    private void init() throws IOException {
        options = new Interpreter.Options();
        // Set gpu delegate
        if (useGpu) {
            CompatibilityList compatList = new CompatibilityList();
            GpuDelegate.Options delegateOptions = compatList.getBestOptionsForThisDevice();
            gpuDelegate = new GpuDelegate(delegateOptions);
            options.addDelegate(gpuDelegate);
        }
        // Set TF Lite interpreter
        interpreter = new Interpreter(loadModelFile(), options);
    }
    /** @brief Load .tflite model file to ByteBuffer
     * @date 23/01/25
     */
    private ByteBuffer loadModelFile() throws IOException {
        AssetFileDescriptor assetFileDescriptor = assetManager.openFd(MODEL_NAME);
        FileInputStream fileInputStream = new FileInputStream(assetFileDescriptor.getFileDescriptor());
        FileChannel fileChannel = fileInputStream.getChannel();
        long startOffset = assetFileDescriptor.getStartOffset();
        long declaredLength = assetFileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }

    public void run(Object a, Object b) {
        interpreter.run(a, b);
    }
    /** @brief Prepare the input tensor from low resolution image
     * @date 23/01/25
     */
    public TensorImage prepareInputTensor(Bitmap bitmap_lr) {
        TensorImage inputImage = TensorImage.fromBitmap(bitmap_lr);
        int height = bitmap_lr.getHeight();
        int width = bitmap_lr.getWidth();
        ImageProcessor imageProcessor = new ImageProcessor.Builder()
                .add(new ResizeOp(height, width, ResizeOp.ResizeMethod.NEAREST_NEIGHBOR))
                .add(new NormalizeOp(0.0f, 255.0f))
                .build();
        inputImage = imageProcessor.process(inputImage);
        return inputImage;
    }

    /** @brief Prepare the output tensor for super resolution
     * @date 23/01/25
     */
    public TensorImage prepareOutputTensor() {
        TensorImage srImage = new TensorImage(DataType.FLOAT32);
        // int[] srShape = new int[]{1080, 1920, 3};
        int[] srShape = new int[]{1920, 1080, 3};
        srImage.load(TensorBuffer.createFixedSize(srShape, DataType.FLOAT32));
        return srImage;
    }
    /** @brief Convert tensor to bitmap image
     * @date 23/01/25
     * @param outputTensor super resolutioned image
     */
    public Bitmap tensorToImage(TensorImage outputTensor) {
        ByteBuffer srOut = outputTensor.getBuffer();
        srOut.rewind();
        int height = outputTensor.getHeight();
        int width = outputTensor.getWidth();
        Bitmap bmpImage = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888);
        int[] pixels = new int[width * height];
        for (int i = 0; i < width * height; i++) {
            int a = 0xFF;
            float r = srOut.getFloat() * 255.0f;
            float g = srOut.getFloat() * 255.0f;
            float b = srOut.getFloat() * 255.0f;
            pixels[i] = a << 24 | ((int) r << 16) | ((int) g << 8) | ((int) b);
        }
        bmpImage.setPixels(pixels, 0, width, 0, 0, width, height);
        return bmpImage;
    }
}
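As a side illustration (not part of the original code): one way this kind of corruption can arise is in the pixel-packing loop of tensorToImage. If the model emits values outside [0, 1], casting the scaled float to int and shifting it lets bits spill into the neighboring channels of the packed ARGB word. A minimal pure-Java sketch of the effect, with no Android or TFLite dependencies and hypothetical helper names:

```java
public class PixelPackDemo {
    // Packs one ARGB pixel the way tensorToImage does: no clamping.
    static int packUnclamped(float r, float g, float b) {
        return 0xFF << 24 | ((int) (r * 255.0f) << 16) | ((int) (g * 255.0f) << 8) | ((int) (b * 255.0f));
    }

    // Packs one ARGB pixel with each channel clamped to [0, 255] first.
    static int packClamped(float r, float g, float b) {
        int ri = Math.max(0, Math.min(255, (int) (r * 255.0f)));
        int gi = Math.max(0, Math.min(255, (int) (g * 255.0f)));
        int bi = Math.max(0, Math.min(255, (int) (b * 255.0f)));
        return 0xFF << 24 | (ri << 16) | (gi << 8) | bi;
    }

    public static void main(String[] args) {
        // A model output slightly outside [0, 1], as super-resolution nets often produce.
        float r = 1.2f, g = 0.5f, b = -0.1f;
        // Unclamped: 1.2 * 255 = 306 (0x132), whose high bits leak past the red byte;
        // the negative blue value sign-extends over the whole word.
        System.out.printf("unclamped: %08X%n", packUnclamped(r, g, b));
        // Clamped: red saturates at 255, blue floors at 0 -> a sane color.
        System.out.printf("clamped:   %08X%n", packClamped(r, g, b));
    }
}
```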
Update 1: I made all the changes proposed in the answer below. That got further, but the output now looks pixelated and purple:
Make a few adjustments:

Update the MODEL_NAME variable to load the esrgan.tflite model:

private final String MODEL_NAME = "esrgan.tflite";
Update prepareInputTensor to make sure the image is resized to 50x50 before it is fed into the model:

public TensorImage prepareInputTensor(Bitmap bitmap_lr) {
    // Resize the input image to 50x50.
    ImageProcessor imageProcessor = new ImageProcessor.Builder()
            .add(new ResizeOp(50, 50, ResizeOp.ResizeMethod.NEAREST_NEIGHBOR))
            .add(new NormalizeOp(0.0f, 255.0f))
            .build();
    TensorImage inputImage = TensorImage.fromBitmap(bitmap_lr);
    inputImage = imageProcessor.process(inputImage);
    return inputImage;
}
Update prepareOutputTensor to adjust the output tensor shape to match the expected 200x200 output size:

public TensorImage prepareOutputTensor() {
    TensorImage srImage = new TensorImage(DataType.FLOAT32);
    // Adjusted to 200x200 and added batch size dimension
    int[] srShape = new int[]{1, 200, 200, 3};
    srImage.load(TensorBuffer.createFixedSize(srShape, DataType.FLOAT32));
    return srImage;
}
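The shape change matters because TensorBuffer.createFixedSize allocates exactly the product of the dimensions in float elements; if the buffer the interpreter writes into does not match the model's output shape, the pixel loop reads misaligned data. A quick sanity check of the expected buffer size (plain Java arithmetic, not a TFLite API call):

```java
public class OutputShapeCheck {
    // Number of elements for a tensor of the given NHWC shape.
    static int elementCount(int[] shape) {
        int n = 1;
        for (int d : shape) n *= d;
        return n;
    }

    public static void main(String[] args) {
        int[] srShape = {1, 200, 200, 3};   // batch, height, width, channels
        int floats = elementCount(srShape); // 120000 float values
        int bytes = floats * 4;             // FLOAT32 -> 4 bytes per value
        System.out.println(floats + " floats, " + bytes + " bytes");
    }
}
```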
Update tensorToImage to make sure the pixel values are clamped to the [0, 255] range after the conversion from float values:

public Bitmap tensorToImage(TensorImage outputTensor) {
    ByteBuffer srOut = outputTensor.getBuffer();
    srOut.rewind();
    int height = outputTensor.getHeight();
    int width = outputTensor.getWidth();
    Bitmap bmpImage = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888);
    int[] pixels = new int[width * height];
    for (int i = 0; i < width * height; i++) {
        int a = 0xFF;
        int r = (int) (srOut.getFloat() * 255.0f);
        int g = (int) (srOut.getFloat() * 255.0f);
        int b = (int) (srOut.getFloat() * 255.0f);
        // Clamp the pixel values to the [0, 255] range.
        r = Math.max(0, Math.min(255, r));
        g = Math.max(0, Math.min(255, g));
        b = Math.max(0, Math.min(255, b));
        pixels[i] = a << 24 | (r << 16) | (g << 8) | b;
    }
    bmpImage.setPixels(pixels, 0, width, 0, 0, width, height);
    return bmpImage;
}
Make sure the rest of the code is adjusted to work with the esrgan.tflite model as well, so that the preprocessing and postprocessing steps match its specific requirements. If needed, adjust the normalization values in NormalizeOp, and double-check that the input and output data types match the model's specification.
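On the NormalizeOp point: NormalizeOp(mean, stddev) maps each channel value x to (x - mean) / stddev, so NormalizeOp(0.0f, 255.0f) rescales [0, 255] pixels into [0, 1], while NormalizeOp(0.0f, 1.0f) leaves them untouched. If the model actually expects raw [0, 255] floats (some ESRGAN exports do; check the model's documentation), both the normalization and the * 255.0f scaling in tensorToImage would need to change together. A pure-Java sketch of the two mappings (the same arithmetic, not the TFLite support-library implementation):

```java
public class NormalizeDemo {
    // Same arithmetic as the support library's NormalizeOp: (x - mean) / stddev.
    static float normalize(float x, float mean, float stddev) {
        return (x - mean) / stddev;
    }

    public static void main(String[] args) {
        // NormalizeOp(0.0f, 255.0f): a 255 pixel becomes 1.0 -> model sees [0, 1].
        System.out.println(normalize(255.0f, 0.0f, 255.0f));
        // NormalizeOp(0.0f, 1.0f): identity -> model sees raw [0, 255] values.
        System.out.println(normalize(255.0f, 0.0f, 1.0f));
    }
}
```

If the normalization is changed to the identity, the output conversion must drop its rescaling correspondingly, otherwise the brightness ends up wrong in the reconstructed bitmap.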
I hope this solves your problem.