我正在尝试将预训练的 Mask R-CNN 模型导出为 ONNX 格式。由于该模型的基本配置具有以下结构(这里我添加了
batch_size
作为动态轴):
我想自定义我的模型并将
batch_size
添加到输出(这意味着我需要为每个输出添加新的暗淡)。
我编写了以下代码以使其成为可能:
class MaskRCNNModel(torch.nn.Module):
def __init__(self):
super(MaskRCNNModel, self).__init__()
self.model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights='DEFAULT')
in_features = self.model.roi_heads.box_predictor.cls_score.in_features
self.model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=7)
self.model.load_state_dict(torch.load("saved_dict.torch"))
def forward(self, input):
outputs = self.model.forward(input)
boxes = []
labels = []
scores = []
masks = []
for result in outputs:
box, label, score, mask = result.values()
boxes.append(box)
labels.append(label)
scores.append(score)
masks.append(mask)
return boxes, labels, scores, masks
maskrcnn_model = MaskRCNNModel()
maskrcnn_model.eval()
maskrcnn_model.to(device)
x = torch.rand(1, 3, 512, 512)
x = x.to(device)
maskrcnn_model(x)
torch.onnx.export(maskrcnn_model,
x,
"base_model_100_epochs.onnx",
opset_version=11,
input_names=["input"],
output_names=["boxes", "labels", "scores", "masks"])
但是上面的代码不会更改任何导出参数。输出的结构保持不变:
我应该如何自定义
forward
方法才能将 batch_size
添加到 ONNX 模型输出中?
根据我最初的评论,我不鼓励使用 ONNX 部署大多数
torchvision
模型。它是一个很棒的模块,只是它最初编写的目的并不是为了与静态图很好地配合。
如果考虑吞吐量,则这种实现 Mask R-CNN 不是最佳选择。对于早期的 ONNX opsets,我让这个模型在回退到 CPU 时将大部分执行时间用于 h2d/d2h 操作。我建议检查 ultralytics 的 YOLOv8 以获取实例分割的更新版本,或者在 github 上找到的一些静态实现。 Torchvision Mask R-CNN 输出
batch = torch.randn((2, 3, 256, 256)) # Input two images
output = mask_rcnn(batch) # run model
results1, results2 = output # One dictionary per batch
for key, value in results1:
print(key, value.shape)
>>> boxes [10, 4]
>>> labels [10]
>>> scores [10]
>>> masks [10, 1, 256, 256]
for key, value in results2:
print(key, value.shape)
>>> boxes [3, 4]
>>> labels [3]
>>> scores [3]
>>> masks [3, 1, 256, 256]
为什么你的方法不起作用
python
类型。在
torch.onnx.export
期间,列表、字典、元组等没有特殊含义,它们的条目要么保存为张量,要么保存为常量。因此,您的自定义前向传递所做的唯一事情就是更改输出的顺序,例如与前面的示例输出变换为>>> boxes1 [10, 4]
>>> labels1 [10]
>>> scores1 [10]
>>> masks1 [10, 1, 256, 256]
>>> boxes2 [3, 4]
>>> labels2 [3]
>>> scores2 [3]
>>> masks2 [3, 1, 256, 256]
到
>>> boxes1 [10, 4]
>>> boxes2 [3, 4]
>>> labels1 [10]
>>> labels2 [3]
>>> scores1 [10]
>>> scores2 [3]
>>> masks1 [10, 1, 256, 256]
>>> masks2 [3, 1, 256, 256]
值得一读,了解导出过程中如何解释 python
和
torch
类型。目标
boxes [batch_size, num_detections, 4]
labels [batch_size, num_detections]
scores [batch_size, num_detections]
masks [batch_size, num_detections, 1, 256, 256]
我们立即发现,如果不应用任何技巧,这是不可能的。由于批次中的不同图像将具有不同数量的预测对象,因此我们无法创建第一个索引中包含
10
边界框、第二个索引中包含
4
边界框的张量。解决方案 - 填充
def forward(self, input):
# Maximum number of detections the vision model will output per batch
max_detections = self.model.roi_heads.detections_per_img
# Variables for output tensor shapes
# Use tensor.size instead of tensor.shape for dynamic inputs
batch_size, _, input_height, input_width = input.shape
# Create batched output tensors
all_boxes = torch.zeros((batch_size, max_detections, 4))
all_labels = torch.zeros((batch_size, max_detections))
all_scores = torch.zeros((batch_size, max_detections))
# Masks are returned with a redundant channel in the second dimension
all_masks = torch.zeros((batch_size, max_detections, 1, input_height, input_width))
# Number of detections per batch
detections_per_batch = torch.zeros((batch_size, 1))
# Run forward pass
outputs = self.model.forward(input)
for idx, result in enumerate(outputs):
boxes, labels, scores, masks = result.values()
# Number of detections for batch
n_dets = boxes.size(0)
detections_per_batch[idx] = n_dets
# Paste batch results into output tensors
all_boxes[idx, : n_dets] = boxes
all_labels[idx, : n_dets] = labels
all_scores[idx, : n_dets] = scores
all_masks[idx, : n_dets] = masks
return detections_per_batch, all_boxes, all_labels, all_scores, all_masks
此前向传递创建了可能保存的输出张量 所有对象检测,并将每批实现的对象检测复制到其中。为了跟踪哪些条目是零填充的以及哪些是实际检测,在 Mask R-CNN 输出的顶部返回一个张量
detections_per_batch
。然后用它从 ONNX 输出中提取真实的预测
for preds, boxes, labels, scores, masks in zip(*outputs):
detected_boxes = boxes[: preds]
detected_labels = labels[: preds]
...
注意事项
model.roi_heads.detections_per_img
来限制此上限。