PyTorch、pt模型文件转换为torchScript.ts文件

问题描述 投票:0回答:0

我有一个名为 model.pt 的权重模型,用于头部 CT 扫描的大脑分割。 我如何将其转换为 torchscript 文件,以便我可以使用该模型进行部署,

网络定义:

3dUNet, 
in channel: 1(image), 
out channel: 2(brain label and background)

输入定义:

 "image": {
                "type": "image",
                "format": "hounsfield",
                "modality": "CT",
                "num_channels": 1,
                "spatial_shape": [
                    96,
                    96,
                    96
                ],
                "dtype": "float32",
                "value_range": [
                    0,
                    1
                ],
                "is_patch_data": true,
                "channel_def": {
                    "0": "image"
                }
            }
        },

训练/验证分割: 13 个图像用于训练,3 个图像用于验证

输出定义:

          "pred": {
              "type": "image",
              "format": "segmentation",
              "num_channels": 2,
              "spatial_shape": [
                  96,
                  96,
                  96
              ],
              "dtype": "float32",
              "value_range": [
                  0,
                  1
              ],
              "is_patch_data": true,
              "channel_def": {
                  "0": "background",
                  "1": "brain"
              }
          }
      

现在,我如何使用跟踪/脚本转换为 torchsctipt。 这些信息够了吗?

我试过了

import torch

model = torch.load('model/model.pt')

example = torch.rand(13, 96, 96, 96)

traced_script_module = torch.jit.script(model, (example))
torch.save(traced_script_module, "model/traced_resnet_model.ts")

我只使用了模型输入尺寸,我也尝试过

torch.jit.trace
。但都失败了。

任何帮助将非常感激。

pytorch model jit torchscript
© www.soinside.com 2019 - 2024. All rights reserved.