TensorFlow inference on Android produces garbage results


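I trained a model with TensorFlow and then converted it to the TensorFlow Lite format.

Inference works perfectly on my laptop in Python. I then put the model into an Android app and ran inference with the TensorFlow Lite interpreter, but the result is nothing but a completely black image. I ported the Python code to Java as-is and still get this garbage result.

Any ideas about what I might be doing wrong here?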

Python code:

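import os

import cv2
import numpy as np
import tensorflow as tf
# imread and imsave come from an image I/O library that the original post does not show.

def preprocess(img):
    # Scale pixels from [0, 255] to [-1, 1].
    return (img / 255. - 0.5) * 2

def deprocess(img):
    # Map model output from [-1, 1] back to [0, 1].
    return (img + 1) / 2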

img_size = 256

frozen_model_filename = os.path.join('model/tflite', 'model.tflite')

image_1 = cv2.resize(imread(image_1), (img_size, img_size))
X_1 = np.expand_dims(preprocess(image_1), 0)
X_1 = X_1.astype(np.float32)

image_2 = cv2.resize(imread(image_2), (img_size, img_size))
X_2 = np.expand_dims(preprocess(image_2), 0)
X_2 = X_2.astype(np.float32)


interpreter = tf.lite.Interpreter(model_path=frozen_model_filename)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()


interpreter.set_tensor(input_details[0]['index'], X_1)
interpreter.set_tensor(input_details[1]['index'], X_2)
interpreter.invoke()

Output = interpreter.get_tensor(output_details[0]['index'])
Output = deprocess(Output)

imsave('result_tflite.jpg', Output[0])

The corresponding Java code for Android:

private ByteBuffer convertBitmapToByteBuffer(Bitmap bitmap) {

    Bitmap resized = Bitmap.createScaledBitmap(bitmap, IMAGE_SIZE, IMAGE_SIZE, false);

    ByteBuffer byteBuffer;

    if(isQuant) {
        byteBuffer = ByteBuffer.allocateDirect(BATCH_SIZE * IMAGE_SIZE * IMAGE_SIZE * PIXEL_SIZE);
    } else {
        byteBuffer = ByteBuffer.allocateDirect(4 * BATCH_SIZE * IMAGE_SIZE * IMAGE_SIZE * PIXEL_SIZE);
    }

    byteBuffer.order(ByteOrder.nativeOrder());

    int[] intValues = new int[IMAGE_SIZE * IMAGE_SIZE];
    resized.getPixels(intValues, 0, resized.getWidth(), 0, 0, resized.getWidth(), resized.getHeight());

    int pixel = 0;
    byteBuffer.rewind();

    for (int i = 0; i < IMAGE_SIZE; ++i) {
        for (int j = 0; j < IMAGE_SIZE; ++j) {
            final int val = intValues[pixel++];
            if(isQuant){
                byteBuffer.put((byte) ((val >> 16) & 0xFF));
                byteBuffer.put((byte) ((val >> 8) & 0xFF));
                byteBuffer.put((byte) (val & 0xFF));
            } else {
                byteBuffer.putFloat((((val >> 16) & 0xFF) - 0.5f) * 2.0f);
                byteBuffer.putFloat((((val >> 8) & 0xFF) - 0.5f) * 2.0f);
                byteBuffer.putFloat((((val) & 0xFF ) - 0.5f) *  2.0f);
            }
        }
    }
    return byteBuffer;
}

private Bitmap getOutputImage(ByteBuffer output){
    output.rewind();

    int outputWidth = IMAGE_SIZE;
    int outputHeight = IMAGE_SIZE;
    Bitmap bitmap = Bitmap.createBitmap(outputWidth, outputHeight, Bitmap.Config.ARGB_8888);
    int [] pixels = new int[outputWidth * outputHeight];
    for (int i = 0; i < outputWidth * outputHeight; i++) {
        int a = 0xFF;

        float r = (output.getFloat() + 1) / 2.0f;
        float g = (output.getFloat() + 1) / 2.0f;
        float b = (output.getFloat() + 1) / 2.0f;

        pixels[i] = a << 24 | ((int) r << 16) | ((int) g << 8) | (int) b;
    }
    bitmap.setPixels(pixels, 0, outputWidth, 0, 0, outputWidth, outputHeight);
    return bitmap;
}

private void runInference(){

    ByteBuffer byteBufferX1 = convertBitmapToByteBuffer(bitmap_x1);
    ByteBuffer byteBufferX2 = convertBitmapToByteBuffer(bitmap_x2);

    Object[] inputs = {byteBufferX1, byteBufferX2};

    ByteBuffer byteBufferOutput;

    if(isQuant) {
        byteBufferOutput = ByteBuffer.allocateDirect(BATCH_SIZE * IMAGE_SIZE * IMAGE_SIZE * PIXEL_SIZE);
    } else {
        byteBufferOutput = ByteBuffer.allocateDirect(4 * BATCH_SIZE * IMAGE_SIZE * IMAGE_SIZE * PIXEL_SIZE);
    }

    byteBufferOutput.order(ByteOrder.nativeOrder());
    byteBufferOutput.rewind();

    Map<Integer, Object> outputs = new HashMap<>();
    outputs.put(0, byteBufferOutput);

    interpreter.runForMultipleInputsOutputs(inputs, outputs);
    ByteBuffer out = (ByteBuffer) outputs.get(0);
    Bitmap outputBitmap = getOutputImage(out);

    // outputBitmap is just a full black image
}
android tensorflow tensorflow-lite
1 Answer

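The Java and Python interpreters are both built on the same C++ implementation, so the results should be identical. The bug is most likely in your Java code. I think you forgot to multiply or divide by 255: the Python preprocess divides each pixel by 255 before normalizing, while the float branch of convertBitmapToByteBuffer does not, and getOutputImage casts values in the range [0, 1] straight to int without scaling them back up to 0-255.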
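As a rough illustration of the scaling, here is a minimal sketch of the per-channel math with no Android dependencies, so it can be checked on its own. It assumes the float (non-quantized) model and an output range of [-1, 1], as the Python preprocess and deprocess imply. The class and method names (PixelScaling, normalizeChannel, denormalizeChannel, packArgb) are hypothetical and not part of the original post or of TensorFlow Lite.

```java
// Hypothetical helper class: mirrors the Python preprocess()/deprocess() scaling in plain Java.
final class PixelScaling {
    private PixelScaling() {}

    /** Maps an 8-bit channel value (0..255) to [-1, 1], like the Python preprocess(). */
    public static float normalizeChannel(int channel) {
        // Divide by 255 first; without this step the values would span roughly [-1, 509].
        return (channel / 255.0f - 0.5f) * 2.0f;
    }

    /** Maps a model output in [-1, 1] to an 8-bit channel value, clamped to 0..255. */
    public static int denormalizeChannel(float value) {
        float unit = (value + 1.0f) / 2.0f;        // same as the Python deprocess(): [-1, 1] -> [0, 1]
        int scaled = Math.round(unit * 255.0f);    // scale to 0..255 before converting to int
        return Math.max(0, Math.min(255, scaled)); // clamp so a stray value cannot spill into another channel
    }

    /** Packs three model outputs into one opaque ARGB_8888 pixel value. */
    public static int packArgb(float r, float g, float b) {
        return 0xFF << 24
                | denormalizeChannel(r) << 16
                | denormalizeChannel(g) << 8
                | denormalizeChannel(b);
    }
}
```

In the code from the question, normalizeChannel would stand in for the expression inside each putFloat call, and packArgb for the line that builds pixels[i]. Casting a value in [0, 1] directly to int yields 0 or 1 for every channel, which is consistent with the all-black bitmap.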
