如何加载export_inference_graph.py保存的训练模型?

问题描述 投票:0回答:1

我下面是一个使用tensorflow的1.15.0对象检测API的例子,教程中明确了以下几个方面。

  • 如何下载一个模型
  • 如何加载一个带有.xml文件的自定义数据库,并从中制作.cvs文件,然后制作.record文件。
  • 如何配置培训流水线
  • 如何获得张量板图
  • 如何训练净节约检查点(使用model_main.py)。
  • 如何导出(保存)模型(使用export_inference_graph.py)。

然而,我一直无法完成的是,加载保存的模型来使用它。tf.saved_model.loader.load(sess, flags, export_dir但我得到

INFO:tensorflow:Saver not created because there are no variables in the graph to restore.
INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored.

中给出的文件夹 export_dir 具有以下结构。

+dir
   +saved_model
      -saved_model.pb
   -model.ckpt.data-00000-of-00001
   -model.ckpt.index
   -checkpoint
   -frozen_inference_graph.pb
   -model.ckpt.meta
   -pipeline.config

我的最终目标是用相机捕捉图像,并将其输入到网络中进行实时物体检测.作为中间步骤,现在我只想能够输入一张图片并得到输出. 我能够训练网络,但现在我不能使用它。

先谢谢你。

tensorflow conv-neural-network object-detection-api transfer-learning tensorflow-model-garden
1个回答
2
投票

我发现 如何下载模型的例子 由于例子中下载的文件的文件夹格式和我的代码中的一样,所以我只需要调整它。

下载模型的最终函数是

def load_model(model_name):
  base_url = 'http://download.tensorflow.org/models/object_detection/'
  model_file = model_name + '.tar.gz'
  model_dir = tf.keras.utils.get_file(
    fname=model_name, 
    origin=base_url + model_file,
    untar=True)

  model_dir = pathlib.Path(model_dir)/"saved_model"

  model = tf.saved_model.load(str(model_dir))
  model = model.signatures['serving_default']

  return model

然后我用这个函数创建了这个新的函数

def load_local_model(model_path):
  model_dir = pathlib.Path(model_path)/"saved_model"

  model = tf.saved_model.load(str(model_dir))
  model = model.signatures['serving_default']

  return model

起初这并不奏效,因为 tf.saved_model.load 预期有3个参数,但通过导入两个 输入 块,我不知道是哪个导入做了手脚,为什么(当我得到答案时,我会编辑这个答案),但目前这段代码工作,这个例子让我们做更多的事情。

导入块如下

import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile

from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image
from IPython.display import display

from object_detection.utils import ops as utils_ops
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util

编辑真正需要这样做的是下面的块。

import os
import pathlib


if "models" in pathlib.Path.cwd().parts:
  while "models" in pathlib.Path.cwd().parts:
    os.chdir('..')
elif not pathlib.Path('models').exists():
  !git clone --depth 1 https://github.com/tensorflow/models

%%bash
cd models/research/
protoc object_detection/protos/*.proto --python_out=.

%%bash 
cd models/research
pip install .

否则,这个导入块将无法工作。

from object_detection.utils import ops as utils_ops
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util
© www.soinside.com 2019 - 2024. All rights reserved.