PyTorch SSD (VGG16 backbone) 不学习

问题描述 投票:0回答:1

我目前正在使用 PyTorch 尝试在自定义数据集上训练 SSD 检测器。我之前在同一框架 Faster RCNN 中完成了另一个模型的训练,并对结果感到满意 (~80 mAP @IoU=0.50)。然而,训练SSD似乎根本不起作用。

这是在自定义数据集上训练 25 个时期(数据集中约 5000 张图像)后的指标,小批量大小为 32,学习率为 1e-3。

IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.002
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.015
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.001
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.011
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.011
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.018
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.062
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.174
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.100

这是我尝试构建模型的方式(所有默认情况下首先进行最低限度的工作)。

model = torchvision.models.detection.ssd300_vgg16(weights_backbone=torchvision.models.VGG16_Weights)
in_channels = det_utils.retrieve_out_channels(model.backbone, (300, 300))
num_anchors = model.anchor_generator.num_anchors_per_location()
model.head = SSDHead(in_channels=in_channels, num_anchors=num_anchors, num_classes=3)

我正在尝试为 VGG (Imagenet) 加载预训练的主干,然后替换头部,我认为这会导致一个模型,该模型具有特征提取器的预训练权重和头部的 xavier 初始化权重。这是它应该如何构建?

现在我确信图像已正确加载,因为我使用完全相同的

__getitem__()
函数,我曾为 Faster RCNN 加载相同的数据(+ 我通过可视化边界框用 SSD 代码确认了这一点) .数据集中最常见的对象也不是特别小,而是很大(图像大小的 10-40%)。

我努力寻找 SSD 的任何工作示例,只有 Faster RCNN 和其他一些。所以我所做的是基于我对 SSD 的理解和检查存储库中的代码,这就是为什么我相信我可能在模型初始化中做错了。

编辑:

这是优化器

torch.optim.SGD(params, lr=0.001, momentum=0.9, weight_decay=0.0005)
  • PyTorch 版本:1.13.0a0+git3c2c2cc
  • 火炬视觉版本:0.14.0a0+6ca9c76
  • 显卡:RTX 3070
  • CUDA 版本:11.2
  • CuDNN 版本:8.2.1

从源代码编译的 PyTorch 和 Torchvision

python pytorch object-detection torchvision
1个回答
0
投票

我遇到了同样的问题。你解决了这个问题吗? 😥

© www.soinside.com 2019 - 2024. All rights reserved.