tf_rep.export_graph(tf_model_path): KeyError: 'input.1'


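I am trying to convert an onnx model to tflite, and I get an error when executing the line tf_rep.export_graph(tf_model_path). This question has been asked on SO before, but no clear solution was provided.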

Installed requirements:

tensorflow 2.12.0
onnx 1.14.0
onnx-tf 1.10.0
Python 3.10.12

  import torch
  import onnx
  import tensorflow as tf
  import onnx_tf
  from torchvision.models import resnet50

  # Load the PyTorch ResNet50 model
  pytorch_model = resnet50(pretrained=True)
  pytorch_model.eval()

  # Export the PyTorch model to ONNX format
  input_shape = (1, 3, 224, 224)
  dummy_input = torch.randn(input_shape)
  onnx_model_path = 'resnet50.onnx'
  torch.onnx.export(pytorch_model, dummy_input, onnx_model_path, opset_version=12, verbose=False)

  # Load the ONNX model
  onnx_model = onnx.load(onnx_model_path)

  # Convert the ONNX model to TensorFlow format
  # (export_graph writes a TensorFlow SavedModel directory, as in the traceback below)
  tf_model_path = 'resnet50'

  from onnx_tf.backend import prepare

  tf_rep = prepare(onnx_model)
  tf_rep.export_graph(tf_model_path)    # raises KeyError: 'input.1'

Error:

WARNING:absl:`input.1` is not a valid tf.function parameter name. Sanitizing to `input_1`.
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-4-f35b83c104b8> in <cell line: 8>()
    6 tf_model_path = 'resnet50'
    7 tf_rep = prepare(onnx_model)
----> 8 tf_rep.export_graph(tf_model_path)

35 frames
/usr/local/lib/python3.10/dist-packages/onnx_tf/handlers/backend/conv_mixin.py in tf__conv(cls, node, input_dict, transpose)
    17                 do_return = False
    18                 retval_ = ag__.UndefinedReturnValue()
---> 19                 x = ag__.ld(input_dict)[ag__.ld(node).inputs[0]]
    20                 x_rank = ag__.converted_call(ag__.ld(len), (ag__.converted_call(ag__.ld(x).get_shape, (), None, fscope),), None, fscope)
    21                 x_shape = ag__.converted_call(ag__.ld(tf_shape), (ag__.ld(x), ag__.ld(tf).int32), None, fscope)

KeyError: in user code:

    File "/usr/local/lib/python3.10/dist-packages/onnx_tf/backend_tf_module.py", line 99, in __call__  *
        output_ops = self.backend._onnx_node_to_tensorflow_op(onnx_node,
    File "/usr/local/lib/python3.10/dist-packages/onnx_tf/backend.py", line 347, in _onnx_node_to_tensorflow_op  *
        return handler.handle(node, tensor_dict=tensor_dict, strict=strict)
    File "/usr/local/lib/python3.10/dist-packages/onnx_tf/handlers/handler.py", line 59, in handle  *
        return ver_handle(node, **kwargs)
    File "/usr/local/lib/python3.10/dist-packages/onnx_tf/handlers/backend/conv.py", line 15, in version_11  *
        return cls.conv(node, kwargs["tensor_dict"])
    File "/usr/local/lib/python3.10/dist-packages/onnx_tf/handlers/backend/conv_mixin.py", line 29, in conv  *
        x = input_dict[node.inputs[0]]

    KeyError: 'input.1'
1 answer

The problem lies in the input (parameter) names of the onnx model.

import onnx

onnx_model = onnx.load(onnx_model_path)
print("Model Inputs: ", [inp.name for inp in onnx_model.graph.input])

Model Inputs:  ['input.1']
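To see which inputs are affected in a model with more than one input, the names that are not valid Python identifiers (the ones tf.function has to sanitize, per the warning in the question) can be collected into a mapping automatically. This is a sketch under an assumption: the sanitization rule below (every character outside [A-Za-z0-9_] becomes "_") only approximates what TensorFlow does, and build_name_map is a hypothetical helper, not part of onnx or onnx-tf.

```python
import re

# Assumption: approximates the sanitization reported in the warning
# ("input.1" -> "input_1"); TensorFlow's exact rule may differ by version.
_INVALID_CHARS = re.compile(r"[^0-9A-Za-z_]")


def build_name_map(input_names):
    """Hypothetical helper: map every input name that is not a valid Python
    identifier to a sanitized replacement. Valid names are left out."""
    name_map = {}
    taken = set(input_names)
    for name in input_names:
        if name.isidentifier():
            continue  # already usable as a tf.function parameter name
        new_name = _INVALID_CHARS.sub("_", name)
        if not new_name or new_name[0].isdigit():
            new_name = "arg_" + new_name  # identifiers cannot start with a digit
        # Append a numeric suffix if the sanitized name collides with another input.
        candidate, suffix = new_name, 1
        while candidate in taken:
            candidate = f"{new_name}_{suffix}"
            suffix += 1
        taken.add(candidate)
        name_map[name] = candidate
    return name_map
```

For this model, build_name_map([inp.name for inp in onnx_model.graph.input]) should return {'input.1': 'input_1'}, the same mapping the answer hard-codes as name_map.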

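Here the name input.1 cannot be handled during the TensorFlow conversion (the warning in the question shows tf.function sanitizing it to input_1), so it must be replaced with input_1. The following code does this: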

import onnx
from onnx import helper

onnx_model = onnx.load(onnx_model_path)

# Define a mapping from old names to new names
name_map = {"input.1": "input_1"}

# Initialize a list to hold the new inputs
new_inputs = []

# Iterate over the inputs and change their names if needed
for inp in onnx_model.graph.input:
    if inp.name in name_map:
        # Create a new ValueInfoProto with the new name
        new_inp = helper.make_tensor_value_info(name_map[inp.name],
                                                inp.type.tensor_type.elem_type,
                                                [dim.dim_value for dim in inp.type.tensor_type.shape.dim])
        new_inputs.append(new_inp)
    else:
        new_inputs.append(inp)

# Clear the old inputs and add the new ones
onnx_model.graph.ClearField("input")
onnx_model.graph.input.extend(new_inputs)

# Go through all nodes in the model and replace the old input name with the new one
for node in onnx_model.graph.node:
    for i, input_name in enumerate(node.input):
        if input_name in name_map:
            node.input[i] = name_map[input_name]

# Save the renamed ONNX model
onnx.save(onnx_model, 'resnet50-new.onnx')

The new input names are:

Model Inputs:  ['input_1']

With the renamed model, the export completes without errors and the tflite file can be produced:

import onnx

onnx_model_path = 'resnet50-new.onnx'
onnx_model = onnx.load(onnx_model_path)
from onnx_tf.backend import prepare

tf_model_path = 'resnet50'
tf_rep = prepare(onnx_model)
tf_rep.export_graph(tf_model_path)
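In onnx-tf 1.10, export_graph writes a TensorFlow SavedModel directory (resnet50 here) rather than a .tflite file, so one more step is needed to reach the question's goal. A minimal sketch using TensorFlow 2.x's tf.lite.TFLiteConverter.from_saved_model; the helper name saved_model_to_tflite and the output file name are illustrative, not taken from the answer.

```python
import tensorflow as tf


def saved_model_to_tflite(saved_model_dir: str, tflite_path: str) -> int:
    """Hypothetical helper: convert a SavedModel directory to a .tflite
    flatbuffer, write it to tflite_path, and return its size in bytes."""
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    tflite_bytes = converter.convert()  # serialized TFLite model (bytes)
    with open(tflite_path, "wb") as f:
        f.write(tflite_bytes)
    return len(tflite_bytes)
```

For the model above this would be called as saved_model_to_tflite('resnet50', 'resnet50.tflite'). If the converter rejects some TensorFlow ops, adding tf.lite.OpsSet.SELECT_TF_OPS to converter.target_spec.supported_ops (alongside tf.lite.OpsSet.TFLITE_BUILTINS) is the usual fallback, at the cost of a larger runtime.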