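I created an image segmentation model in Python with the help of TensorFlow; it gives me a mask as output.
In Python, I can load the model and generate the output with just two lines (as input I use 128x128 grayscale images, passed in batches of shape [1,128,128,1]):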
model.load_weights("path/to/model")
test_preds = model.predict(X_test)
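For a tensor of shape [1,128,128,1] the batch and channel dimensions both have size 1, so in a dense row-major (channels-last, TensorFlow's default) buffer the flat offset of pixel (y, x) reduces to y * 128 + x. A small C++ sketch of the general NHWC offset, assuming that dense layout; `NhwcOffset` is a hypothetical helper name, not part of TensorFlow:

```cpp
#include <cassert>
#include <cstddef>

// Flat offset of element (n, y, x, c) in a dense, row-major NHWC tensor of
// shape [N, H, W, C]: c varies fastest, then x, then y, then n.
std::size_t NhwcOffset(std::size_t n, std::size_t y, std::size_t x, std::size_t c,
                       std::size_t H, std::size_t W, std::size_t C) {
    assert(y < H && x < W && c < C);  // n is not bounds-checked (N is not passed in)
    return ((n * H + y) * W + x) * C + c;
}
```

With N = C = 1 and H = W = 128, `NhwcOffset(0, y, x, 0, 128, 128, 1)` equals `y * 128 + x`, which is the index the C++ loops in the solution use to fill and read the tensors.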
The next step is binarization, which gives me a mask that contains only the values 255 and 0:
preds_test_thresh = (test_preds >= 0.5).astype(np.uint8) * 255  # bool -> 0/1 -> 0/255
test_img = preds_test_thresh[1, :, :, 0]
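The same thresholding can be done in C++ on a plain float buffer, such as the pointer returned by `typed_output_tensor<float>(0)`. A minimal sketch, assuming the model outputs one probability per pixel; `BinarizeMask` is a hypothetical helper, and the 0.5 threshold is taken from the Python code:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Writes 255 for every probability >= threshold and 0 otherwise,
// mirroring the Python step (test_preds >= 0.5) scaled to a 0/255 mask.
std::vector<std::uint8_t> BinarizeMask(const float* probs, std::size_t count,
                                       float threshold = 0.5f) {
    assert(probs != nullptr || count == 0);
    std::vector<std::uint8_t> mask(count);
    for (std::size_t i = 0; i < count; ++i) {
        mask[i] = static_cast<std::uint8_t>(probs[i] >= threshold ? 255 : 0);
    }
    return mask;
}
```

For a 128x128 output, the resulting vector could be wrapped for display with `cv::Mat(128, 128, CV_8UC1, mask.data())`, as long as the vector outlives the `cv::Mat` (that constructor does not copy the data).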
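My goal now is to use this model in C++. To do this, I first converted the model to a TF-Lite model, and now I want to load it in C++ and generate the output.
My approach is as follows: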
// Create model from file
auto model = tflite::FlatBufferModel::BuildFromFile("path/to/model");
if (model == nullptr)
wxLogMessage("Model not loaded");
else
wxLogMessage("Model loaded");
// Create an Interpreter with an InterpreterBuilder.
std::unique_ptr<tflite::Interpreter> interpreter;
tflite::ops::builtin::BuiltinOpResolver resolver;
tflite::InterpreterBuilder(*model, resolver)(&interpreter);
if (!interpreter)
wxLogMessage("Interpreter not loaded");
if (interpreter->AllocateTensors() != kTfLiteOk)
wxLogMessage("Allocation failed");
else
wxLogMessage("Allocation success");
// load image; get blue channel; resize to 128x128
std::string image_path = samples::findFile("path/to/image");
cv::Mat img = cv::imread(image_path);
wxLogMessage(wxString::Format("%d x %d x %d", img.size[1], img.size[0], img.channels()));
Mat bgr[3];
split(img, bgr);
Mat channelImg = bgr[0];
Mat inputImg;
channelImg.convertTo(inputImg, CV_32FC1, 1.0 / 255.0);
cv::resize(inputImg, inputImg, cv::Size(128, 128));
wxLogMessage(wxString::Format("%d x %d x %d", inputImg.size[1], inputImg.size[0], inputImg.channels()));
// Fill input buffer
float* input = interpreter->typed_input_tensor<float>(0);
memcpy(input, inputImg.data, 128 * 128 * sizeof(float));
// invoke interpreter
if (interpreter->Invoke() != kTfLiteOk) {
wxLogMessage("Failed to invoke");
}
// get output
float* output = interpreter->typed_output_tensor<float>(0);
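The input step above (`convertTo` with a 1/255 scale, then a `memcpy` of 128 * 128 floats) relies on the tensor receiving dense, row-major floats in the value range the model was trained on. A standalone sketch of that step, assuming an 8-bit single-channel source image; `FillInputTensor` is a hypothetical helper, and with OpenCV the arguments would presumably be `img.data`, `img.step`, `img.rows` and `img.cols`:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Copies an 8-bit single-channel image into a dense float tensor buffer of
// height * width elements, multiplying each pixel by `scale`.
// `row_step` is the number of bytes between the starts of consecutive rows
// (>= width), so padded rows are skipped correctly.
void FillInputTensor(const std::uint8_t* pixels, std::size_t row_step,
                     std::size_t height, std::size_t width,
                     float scale, float* tensor) {
    assert(row_step >= width);
    for (std::size_t y = 0; y < height; ++y) {
        const std::uint8_t* row = pixels + y * row_step;
        for (std::size_t x = 0; x < width; ++x) {
            tensor[y * width + x] = static_cast<float>(row[x]) * scale;
        }
    }
}
```

Whether `scale` should be `1.0f / 255.0f` (as in the code above) or `1.0f` (raw 0 to 255 values, as in the later solution) depends on how `X_test` was preprocessed in Python, which the post does not show.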
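I pieced the C++ code together from various examples, and I do get float values. However, I don't get an image as output, so I have tried several ways of building a Mat object from them, but unfortunately none of them produced the correct output image.
So my question is: how do I turn the model's output into an image that I can keep working with (as in the Python code above)? Or do I have to change something in the C++ code above to get the output?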
I solved the problem myself; this is the code:
// File path to the TensorFlow Lite model (.tflite)
const char* model_path = "pat";
// Load the TensorFlow Lite model
auto model = tflite::FlatBufferModel::BuildFromFile(model_path);
std::unique_ptr<tflite::Interpreter> interpreter;
tflite::ops::builtin::BuiltinOpResolver resolver;
tflite::InterpreterBuilder builder(*model, resolver);
builder(&interpreter);
// Check whether the interpreter has been successfully created
if (!interpreter)
wxLogMessage("Interpreter not loaded");
// Allocate memory for the model's tensors
if (interpreter->AllocateTensors() != kTfLiteOk)
wxLogMessage("Tensor allocation failed");
// Resize image to fit the model input [128x128x1]
const int image_width = 128;
const int image_height = 128;
cv::Mat input_image = cv::imread("path/to/image");
if (input_image.empty())
{
wxLogMessage(wxString::Format("Could not read the image: %s", "path/to/image"));
}
// Keep only the first (blue) channel of the BGR image
Mat bgr[3];
split(input_image, bgr);
input_image = bgr[0];
cv::resize(input_image, input_image, cv::Size(image_width, image_height));
imshow("Display window", input_image);
waitKey(0);
// Pointer to the input tensor of the interpreter
float* input_tensor_data = interpreter->typed_input_tensor<float>(0);
// Copy the image pixels into the input tensor (raw 0 to 255 values; divide by 255.0f here if the model was trained on inputs scaled to [0, 1])
for (int y = 0; y < image_height; ++y) {
for (int x = 0; x < image_width; ++x) {
input_tensor_data[y * image_width + x] = static_cast<float>(input_image.at<uchar>(y, x));
}
}
// Run the model
if (interpreter->Invoke() != kTfLiteOk)
wxLogMessage("Failed to invoke");
// Inspect the shape of every output tensor (debug aid; the mask is expected to be [1, 128, 128, 1])
int output_tensor_count = interpreter->outputs().size();
for (int i = 0; i < output_tensor_count; ++i) {
int output_tensor_index = interpreter->outputs()[i];
TfLiteIntArray* output_dims = interpreter->tensor(output_tensor_index)->dims;
}
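// Position of the mask within interpreter->outputs() (not a global tensor index)
int output_tensor_index = 0;
TfLiteTensor* output_tensor = interpreter->tensor(interpreter->outputs()[output_tensor_index]);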
// output image with size of [128x128x1]
const int output_image_width = 128;
const int output_image_height = 128;
// Pointer to the output tensor data
float* output_data = interpreter->typed_output_tensor<float>(output_tensor_index);
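// Scale the probabilities in [0, 1] to grey values 0 to 255 (a soft mask; compare against 0.5 instead to get the binary mask from Python)
cv::Mat output_image(output_image_height, output_image_width, CV_8UC1);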
for (int y = 0; y < output_image_height; ++y) {
for (int x = 0; x < output_image_width; ++x) {
output_image.at<uchar>(y, x) = static_cast<uchar>(output_data[y * output_image_width + x] * 255.0);
}
}
imshow("Display window", output_image);
waitKey(0);
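The solution reads `output_dims` in a loop but never uses them, and then assumes the first output holds 128 * 128 floats. Since the Python code indexes the prediction as [batch, y, x, channel], a 4D output of shape [1, 128, 128, 1] is expected, and a check like the one below, run before the pixel loop, would catch a mismatch instead of reading past the buffer. This is a sketch: `IsSingleMaskShape` is a hypothetical helper whose `dims`/`rank` parameters correspond to the `data` and `size` fields of TF Lite's `TfLiteIntArray`:

```cpp
#include <cassert>

// True if a tensor shape is exactly [1, height, width, 1], i.e. one
// single-channel mask. `dims` and `rank` correspond to the `data` and
// `size` fields of TfLiteIntArray.
bool IsSingleMaskShape(const int* dims, int rank, int height, int width) {
    assert(rank >= 0);
    if (rank != 4 || dims == nullptr) {
        return false;
    }
    return dims[0] == 1 && dims[1] == height && dims[2] == width && dims[3] == 1;
}
```

With the interpreter from the code above, it would presumably be called as `const TfLiteIntArray* dims = interpreter->tensor(interpreter->outputs()[0])->dims;` followed by `IsSingleMaskShape(dims->data, dims->size, 128, 128)`.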