SSD网络是一种单阶段的目标检测方法,目标检测方法旨在给定的图片中找出目标物体的坐标位置和所属类别。我们在这里来梳理一下训练的大致流程谨供参考,我参考的算法实现为:https://github.com/amdegroot/ssd.pytorch
1.特征提取
SSD网络的输入一般是 300x300x3的原始图片矩阵、n个坐标标签和n个分类标签,其中n代表每幅图中目标物体的个数。其特征提取网络直接由卷积网络得到最终的预测值,共使用了6种不同尺寸的特征图:38x38、19x19、10x10、5x5、3x3、1x1,最终得到8732个预测框,具体网络结构如下图所示:
2.训练流程
SSD网络在特征提取会后会直接得到 8732x4的坐标调整值预测和 8732x21的置信度预测及具体分类预测,其训练流程如下:
1.计算anchor与gt_bbox之间的IOU值,根据正负样本阈值选择正负样本,并计算得到样本标签
2.根据样本的预测值和标签值计算坐标损失、置信度损失和分类损失,其中坐标损失只用正样本的计算