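How do I export a torch model with dynamic inputs to ONNX? I declared the dynamic dimensions with dynamic_axes, but at inference time the model does not behave dynamically. My code: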
```python
import numpy as np
import torch

# `model` and `device` are defined earlier in my script (FastSpeech 2 model).
input_names = ['speakers', 'texts', 'src_lens', 'max_src_len']
output_names = ['output', 'postnet_output', 'p_predictions', 'e_predictions', 'log_d_predictions', 'd_rounded',
                'src_masks', 'mel_masks', 'src_lens', 'mel_lens']
dynamic_axes = {
    "texts": {1: "texts_len"},
    "output": {1: "output_len"},
    "postnet_output": {1: "postnet_output_len"},
    "p_predictions": {1: "p_predictions_len"},
    "e_predictions": {1: "e_predictions_len"},
    "log_d_predictions": {1: "log_d_predictions_len"},
    "d_rounded": {1: "d_rounded_len"},
    "src_masks": {1: "src_masks_len"}
}
texts_len = 10
speakers = torch.tensor([0])
texts = torch.randint(1, 200, (1, texts_len))
text_lens = torch.tensor([texts_len])
max_len = torch.from_numpy(np.array(texts_len)).to(device)
torch.onnx.export(model, args=(speakers, texts, text_lens, max_len), f="./FastSpeech_2.onnx",
                  input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes, opset_version=11)
```
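For reference: `dynamic_axes` maps a name from `input_names`/`output_names` to `{axis_index: symbolic_name}`. It marks those dimensions of the graph's inputs and outputs as symbolic, but it does not change how values computed inside `forward()` are recorded while the model is traced. Below is a small plain-Python sanity check of the names used in the export call; `check_export_names` is a hypothetical helper, not part of `torch.onnx`. It lists `dynamic_axes` keys that match no input/output name, names used both as an input and as an output, and names with no dynamic axis declared.

```python
from typing import Dict, List


def check_export_names(input_names: List[str], output_names: List[str],
                       dynamic_axes: Dict[str, Dict[int, str]]) -> Dict[str, List[str]]:
    """Hypothetical sanity check for the names passed to torch.onnx.export."""
    known = set(input_names) | set(output_names)
    # Keys in dynamic_axes that do not correspond to any declared input/output.
    unknown_keys = [k for k in dynamic_axes if k not in known]
    # Names that appear both in input_names and in output_names.
    shared = [n for n in input_names if n in set(output_names)]
    # Declared names with no dynamic axis at all (duplicates removed, order kept).
    no_dynamic = list(dict.fromkeys(
        n for n in input_names + output_names if n not in dynamic_axes))
    return {
        "unknown_dynamic_axes_keys": unknown_keys,
        "shared_names": shared,
        "no_dynamic_axes": no_dynamic,
    }


# The names from the export call in this post.
input_names = ['speakers', 'texts', 'src_lens', 'max_src_len']
output_names = ['output', 'postnet_output', 'p_predictions', 'e_predictions', 'log_d_predictions', 'd_rounded',
                'src_masks', 'mel_masks', 'src_lens', 'mel_lens']
dynamic_axes = {
    "texts": {1: "texts_len"},
    "output": {1: "output_len"},
    "postnet_output": {1: "postnet_output_len"},
    "p_predictions": {1: "p_predictions_len"},
    "e_predictions": {1: "e_predictions_len"},
    "log_d_predictions": {1: "log_d_predictions_len"},
    "d_rounded": {1: "d_rounded_len"},
    "src_masks": {1: "src_masks_len"},
}

report = check_export_names(input_names, output_names, dynamic_axes)
print(report)
```

On these lists it reports no unknown keys, `src_lens` used on both sides, and `speakers`, `src_lens`, `max_src_len`, `mel_masks`, `mel_lens` with no declared dynamic axis. Whether any of these matters depends on the model; the check only makes them visible.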
In the export code I use `src_lens=10`, and that works fine.
But when I run inference with this ONNX model and feed `src_lens=50` (or any other value), I get this error:
```text
2022-01-18 16:29:38.644831855 [E:onnxruntime:, sequential_executor.cc:346 Execute] Non-zero status code returned while running Split node. Name:'Split_2888' Status Message: Cannot split using values in 'split' attribute. Axis=0 Input shape={27,256} NumOutputs=10 Num entries in 'split' (must equal number of outputs) was 10 Sum of sizes in 'split' (must equal size of selected axis) was 10
Traceback (most recent call last):
  File "torch2onnx_2.py", line 497, in <module>
    onnx_mode_test()
  File "torch2onnx_2.py", line 471, in onnx_mode_test
    ort_outs = ort_session.run(None, ort_inputs)
  File "/root/anaconda3/envs/tts_fffan/lib/python3.6/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 192, in run
    return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Non-zero status code returned while running Split node. Name:'Split_2888' Status Message: Cannot split using values in 'split' attribute. Axis=0 Input shape={27,256} NumOutputs=10 Num entries in 'split' (must equal number of outputs) was 10 Sum of sizes in 'split' (must equal size of selected axis) was 10
```
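One way to read this error: the graph contains a Split node whose `split` attribute (the list of chunk sizes) is a constant with 10 entries summing to 10, presumably captured when the model was traced with `texts_len = 10`. At run time the tensor reaching that node has 27 rows on axis 0, so the sizes no longer add up. The pure-Python sketch below mimics that check; `static_split` and `recorded_split` are hypothetical names, and `[1] * 10` is only one possible set of sizes consistent with the log (10 entries, sum 10).

```python
from typing import List, Sequence


def static_split(rows: Sequence, split: List[int]) -> List[list]:
    """Split `rows` along axis 0 into chunks whose sizes were fixed in advance,
    mimicking a Split node with a constant 'split' attribute."""
    # The sizes must add up to the length of the axis being split.
    if sum(split) != len(rows):
        raise ValueError(
            f"Sum of sizes in 'split' was {sum(split)}, "
            f"but the selected axis has size {len(rows)}"
        )
    chunks, start = [], 0
    for size in split:
        chunks.append(list(rows[start:start + size]))
        start += size
    return chunks


# Hypothetical constant captured at export time: 10 entries summing to 10.
recorded_split = [1] * 10

# Shape (10, 256): matches what was seen during tracing, so the split succeeds.
ok = static_split([[0.0] * 256 for _ in range(10)], recorded_split)

# Shape (27, 256), as in the log: the recorded sizes no longer fit.
try:
    static_split([[0.0] * 256 for _ in range(27)], recorded_split)
    error_message = ""
except ValueError as exc:
    error_message = str(exc)
print(error_message)
```

Because the sizes live in a node attribute rather than being computed from a tensor at run time, declaring more `dynamic_axes` would not by itself change this node.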
It looks as if the input length has to be `10` and cannot be dynamic.
Can anyone help me? I have the same problem as you... Has anyone found a way to solve it?