DeformableDETR: separating the backbone from the detector


Based on this repository, I am trying to separate the backbone (ResNet-50) from the detector (Deformable DETR) so that I can make some modifications more easily later. As a result, my backbone does not use the NestedTensor wrapper used in the repository, and its output features are only converted to NestedTensor once they are received by Deformable DETR. On the Deformable DETR side, I made a few modifications:

  • The backbone is no longer received as an argument in __init__; instead, only a list with the number of output channels of each feature level is passed;
  • The transformer is no longer received as an argument in __init__; for simplicity, it is now always built directly inside __init__ via build_deforamble_transformer;
  • forward receives x and features, where x is the raw input before the backbone and features is the list of backbone outputs (see the usage sketch after this list);
  • The condition if self.num_feature_levels > len(srcs) was removed, since the two values are now guaranteed to be equal.
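
For context, the call site now looks roughly like this (a sketch with my own variable names; backbone and detector are the two separated modules):

x = torch.ones(2, 3, 69, 69)   # dummy input, filled with ones
features = backbone(x)         # list of 5 feature maps (shapes listed below)
out = detector(x, features)    # forward of the modified DeformableDETR below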

When it receives a (2x3x69x69) dummy input (filled with ones), I get an error at tmp[..., :2] += reference:

RuntimeError: output with shape [100, 2] doesn't match the broadcast shape [2, 100, 2]

I don't understand why.
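
The failing operation is easy to reproduce in isolation: an in-place += cannot expand the tensor it writes into, so it fails as soon as tmp lacks the batch dimension that reference carries (the shapes below are inferred from the error message):

import torch

tmp = torch.zeros(100, 4)           # shape bbox_embed apparently produced
reference = torch.zeros(2, 100, 2)  # batch 2, 100 queries, 2 coords
tmp[..., :2] += reference           # RuntimeError: output with shape [100, 2]
                                    # doesn't match the broadcast shape [2, 100, 2]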

My features have all the expected shapes:

  • 0 torch.Size([2, 64, 69, 69])
  • 1 torch.Size([2, 256, 69, 69])
  • 2 torch.Size([2, 512, 35, 35])
  • 3 torch.Size([2, 1024, 18, 18])
  • 4 torch.Size([2, 2048, 9, 9])

Below is the modified DeformableDETR class, with docstrings and comments removed for brevity.

import math
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

# MLP, _get_clones, NestedTensor, nested_tensor_from_tensor_list,
# build_position_encoding, build_deforamble_transformer and inverse_sigmoid
# are the helpers from the repository linked above; BaseDetector is my own base class.

class DeformableDETR(BaseDetector):
    def __init__(self, channels, num_feature_levels, num_classes, num_queries,
                 aux_loss=True, with_box_refine=False, two_stage=False, **kwargs):
        args = kwargs.pop('args')
        hidden_dim = kwargs.pop('hidden_dim')
        position_embedding = kwargs.pop('position_embedding')
        super(DeformableDETR, self).__init__(**kwargs)

        self.position_encoding = build_position_encoding(hidden_dim, position_embedding)

        self.num_queries = num_queries
        transformer = build_deforamble_transformer(args=args)
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.class_embed = nn.Linear(hidden_dim, num_classes)
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        self.num_feature_levels = num_feature_levels
        if not two_stage:
            self.query_embed = nn.Embedding(num_queries, hidden_dim*2)
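        # one 1x1 projection per backbone level; extra strided 3x3 convs are
        # appended when more feature levels are requested than the backbone provides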
        if num_feature_levels > 1:
            num_backbone_outs = len(channels)
            input_proj_list = []
            for i in range(num_backbone_outs):
                in_channels = channels[i]
                input_proj_list.append(nn.Sequential(
                    nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
                    nn.GroupNorm(32, hidden_dim),
                ))
            for _ in range(num_feature_levels - num_backbone_outs):
                input_proj_list.append(nn.Sequential(
                    nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
                    nn.GroupNorm(32, hidden_dim),
                ))
                in_channels = hidden_dim
            self.input_proj = nn.ModuleList(input_proj_list)
        else:
            self.input_proj = nn.ModuleList([
                nn.Sequential(
                    nn.Conv2d(channels[0], hidden_dim, kernel_size=1),
                    nn.GroupNorm(32, hidden_dim),
                )])
        self.aux_loss = aux_loss
        self.with_box_refine = with_box_refine
        self.two_stage = two_stage

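        # focal-loss style initialization: bias the class logits towards a
        # prior probability of 0.01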
        prior_prob = 0.01
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        self.class_embed.bias.data = torch.ones(num_classes) * bias_value
        nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
        nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
        for proj in self.input_proj:
            nn.init.xavier_uniform_(proj[0].weight, gain=1)
            nn.init.constant_(proj[0].bias, 0)

        num_pred = (transformer.decoder.num_layers + 1) if two_stage else transformer.decoder.num_layers
        if with_box_refine:
            self.class_embed = _get_clones(self.class_embed, num_pred)
            self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
            nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0)
            # hack implementation for iterative bounding box refinement
            self.transformer.decoder.bbox_embed = self.bbox_embed
        else:
            nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0)
            self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)])
            self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)])
            self.transformer.decoder.bbox_embed = None
        if two_stage:
            self.transformer.decoder.class_embed = self.class_embed
            for box_embed in self.bbox_embed:
                nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0)

    def forward(self, x: torch.Tensor, features: List[torch.Tensor], *args):
        out = []
        pos = []

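        # wrap each feature map in a NestedTensor, derive its padding mask and
        # compute the positional encoding for that level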
        for f in features:
            nested_x = nested_tensor_from_tensor_list(f)
            mask = F.interpolate(nested_x.mask[None].float(), size=f.shape[-2:]).to(torch.bool)[0]
            out.append(NestedTensor(f, mask))
            pos.append(self.position_encoding(out[-1]).to(out[-1].tensors.dtype))

        srcs = []
        masks = []
        for l, feat in enumerate(out):
            src, mask = feat.decompose()
            srcs.append(self.input_proj[l](src))
            masks.append(mask)
            assert mask is not None
        
        query_embeds = None
        if not self.two_stage:
            query_embeds = self.query_embed.weight
        hs, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact = self.transformer(srcs, masks, pos, query_embeds)
        
        outputs_classes = []
        outputs_coords = []
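        # debug: print the shape of each backbone feature map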
        for lvl in range(len(features)):
            print(lvl, features[lvl].shape)
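        # turn each decoder layer's hidden states into class logits and boxes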
        for lvl in range(hs.shape[0]):
            if lvl == 0:
                reference = init_reference
            else:
                reference = inter_references[lvl - 1]
            reference = inverse_sigmoid(reference)
            outputs_class = self.class_embed[lvl](hs[lvl])
            tmp = self.bbox_embed[lvl](hs[lvl])
            if reference.shape[-1] == 4:
                tmp += reference
            else:
                tmp[..., :2] += reference # <<< ERROR happens here
            outputs_coord = tmp.sigmoid()
            outputs_classes.append(outputs_class)
            outputs_coords.append(outputs_coord)
        outputs_class = torch.stack(outputs_classes)
        outputs_coord = torch.stack(outputs_coords)

        out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)

        if self.two_stage:
            enc_outputs_coord = enc_outputs_coord_unact.sigmoid()
            out['enc_outputs'] = {'pred_logits': enc_outputs_class, 'pred_boxes': enc_outputs_coord}
        return out
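
For completeness, inverse_sigmoid is the small helper from the repository's utils (reproduced here from memory so the snippet is self-contained; verify against the repo):

def inverse_sigmoid(x, eps=1e-5):
    # numerically stable logit: clamp x into [0, 1] before computing log(x / (1 - x))
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)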
Tags: deep-learning, pytorch, object-detection