• 学习Faster R-CNN代码faster_rcnn(八)


    Faster R-CNN源代码的faster_rcnn文件夹中包含三个文件:faster_rcnn.py、resnet.py、vgg16.py

    1.faster_rcnn.py注释

      class _fasterRCNN(nn.Module):
          """Faster R-CNN skeleton: base features -> RPN -> RoI pooling -> head.

          The class wires the region proposal network and the RoI pooling layers
          together, but it relies on attributes that are not defined here:
          ``dout_base_model`` (read inside ``__init__``, so it must already be set
          when ``__init__`` runs), ``RCNN_base``, ``_head_to_tail``,
          ``RCNN_cls_score``, ``RCNN_bbox_pred`` and ``_init_modules``.
          NOTE(review): presumably the backbone-specific subclasses (VGG16 /
          ResNet) provide these -- confirm against those files.
          """

          def __init__(self, classes, class_agnostic):
              """Create the RPN, the proposal-target layer and the RoI pooling layers.

              Args:
                  classes: Sequence of class entries; stored as-is, and its length
                      sizes the proposal-target layer.
                  class_agnostic: When true, ``forward`` skips the per-class
                      selection of box-regression columns during training.
                      NOTE(review): this presumably means the box head predicts a
                      single 4-vector per RoI -- confirm where ``RCNN_bbox_pred``
                      is built.
              """
              super(_fasterRCNN, self).__init__()
              self.classes = classes
              self.n_classes = len(classes)
              self.class_agnostic = class_agnostic

              # Placeholders only: ``forward`` returns its losses as local values
              # and never writes back to these two attributes.
              self.RCNN_loss_cls = 0
              self.RCNN_loss_bbox = 0

              # Region proposal network, fed with the base feature map.
              self.RCNN_rpn = _RPN(self.dout_base_model)
              # Produces labels / regression targets for the proposals in training.
              self.RCNN_proposal_target = _ProposalTargetLayer(self.n_classes)

              # One layer per supported cfg.POOLING_MODE ('pool', 'align', 'crop').
              self.RCNN_roi_pool = _RoIPooling(cfg.POOLING_SIZE, cfg.POOLING_SIZE, 1.0 / 16.0)
              self.RCNN_roi_align = RoIAlignAvg(cfg.POOLING_SIZE, cfg.POOLING_SIZE, 1.0 / 16.0)

              # In 'crop' mode the crop is taken at twice the pooling size when it
              # is followed by a 2x2 max pool (see ``forward``).
              if cfg.CROP_RESIZE_WITH_MAX_POOL:
                  self.grid_size = cfg.POOLING_SIZE * 2
              else:
                  self.grid_size = cfg.POOLING_SIZE
              self.RCNN_roi_crop = _RoICrop()

          def forward(self, im_data, im_info, gt_boxes, num_boxes):
              """Run the detector on one batch.

              Args:
                  im_data: Image batch; dimension 0 is taken as the batch size.
                  im_info: Per-image metadata forwarded to the RPN.
                  gt_boxes: Ground-truth boxes forwarded to the RPN and, in
                      training, to the proposal-target layer.
                  num_boxes: Ground-truth box counts, forwarded alongside
                      ``gt_boxes``.

              Returns:
                  Tuple ``(rois, cls_prob, bbox_pred, rpn_loss_cls, rpn_loss_bbox,
                  RCNN_loss_cls, RCNN_loss_bbox, rois_label)``. ``cls_prob`` and
                  ``bbox_pred`` are reshaped to ``(batch_size, rois.size(1), -1)``.
                  Outside training all four losses are ``0`` and ``rois_label`` is
                  ``None``.

              Raises:
                  ValueError: If ``cfg.POOLING_MODE`` is not ``'crop'``,
                      ``'align'`` or ``'pool'``.
              """
              batch_size = im_data.size(0)

              im_info = im_info.data
              gt_boxes = gt_boxes.data
              num_boxes = num_boxes.data

              # Feed the images to the base model to obtain the base feature map.
              base_feat = self.RCNN_base(im_data)

              # Feed the base feature map to the RPN to obtain RoIs.
              rois, rpn_loss_cls, rpn_loss_bbox = self.RCNN_rpn(base_feat, im_info, gt_boxes, num_boxes)

              if self.training:
                  # Training phase: use the ground-truth boxes to derive labels
                  # and regression targets for the proposals.
                  roi_data = self.RCNN_proposal_target(rois, gt_boxes, num_boxes)
                  rois, rois_label, rois_target, rois_inside_ws, rois_outside_ws = roi_data

                  # Flatten the batch and RoI dimensions into one.
                  rois_label = Variable(rois_label.view(-1).long())
                  rois_target = Variable(rois_target.view(-1, rois_target.size(2)))
                  rois_inside_ws = Variable(rois_inside_ws.view(-1, rois_inside_ws.size(2)))
                  rois_outside_ws = Variable(rois_outside_ws.view(-1, rois_outside_ws.size(2)))
              else:
                  rois_label = None
                  rois_target = None
                  rois_inside_ws = None
                  rois_outside_ws = None
                  rpn_loss_cls = 0
                  rpn_loss_bbox = 0

              rois = Variable(rois)

              # RoI pooling on the predicted RoIs; each RoI is a row of 5 values.
              if cfg.POOLING_MODE == 'crop':
                  grid_xy = _affine_grid_gen(rois.view(-1, 5), base_feat.size()[2:], self.grid_size)
                  # Swap the last-axis order from (x, y) to (y, x) for the crop layer.
                  grid_yx = torch.stack([grid_xy.data[:, :, :, 1], grid_xy.data[:, :, :, 0]], 3).contiguous()
                  pooled_feat = self.RCNN_roi_crop(base_feat, Variable(grid_yx).detach())
                  if cfg.CROP_RESIZE_WITH_MAX_POOL:
                      pooled_feat = F.max_pool2d(pooled_feat, 2, 2)
              elif cfg.POOLING_MODE == 'align':
                  pooled_feat = self.RCNN_roi_align(base_feat, rois.view(-1, 5))
              elif cfg.POOLING_MODE == 'pool':
                  pooled_feat = self.RCNN_roi_pool(base_feat, rois.view(-1, 5))
              else:
                  # Without this branch an unknown mode would leave pooled_feat
                  # unbound and fail below with an unrelated-looking error.
                  raise ValueError(
                      "Unsupported cfg.POOLING_MODE {!r}; expected 'crop', 'align' or 'pool'".format(
                          cfg.POOLING_MODE))

              # Feed the pooled features to the top (head) model.
              pooled_feat = self._head_to_tail(pooled_feat)

              # Compute the bounding-box offsets.
              bbox_pred = self.RCNN_bbox_pred(pooled_feat)
              if self.training and not self.class_agnostic:
                  # Select, for every RoI, the 4 columns that belong to its label.
                  num_rois = bbox_pred.size(0)
                  bbox_pred_view = bbox_pred.view(num_rois, bbox_pred.size(1) // 4, 4)
                  gather_index = rois_label.view(num_rois, 1, 1).expand(num_rois, 1, 4)
                  bbox_pred_select = torch.gather(bbox_pred_view, 1, gather_index)
                  bbox_pred = bbox_pred_select.squeeze(1)

              # Compute the object classification probabilities.
              cls_score = self.RCNN_cls_score(pooled_feat)
              cls_prob = F.softmax(cls_score, 1)

              RCNN_loss_cls = 0
              RCNN_loss_bbox = 0

              if self.training:
                  # Classification loss.
                  RCNN_loss_cls = F.cross_entropy(cls_score, rois_label)

                  # Bounding-box regression loss (smooth L1).
                  RCNN_loss_bbox = _smooth_l1_loss(bbox_pred, rois_target, rois_inside_ws, rois_outside_ws)

              cls_prob = cls_prob.view(batch_size, rois.size(1), -1)
              bbox_pred = bbox_pred.view(batch_size, rois.size(1), -1)

              return rois, cls_prob, bbox_pred, rpn_loss_cls, rpn_loss_bbox, RCNN_loss_cls, RCNN_loss_bbox, rois_label

          def _init_weights(self):
              """Initialise the weights and biases of the RPN and detection-head layers."""

              def normal_init(m, mean, stddev, truncated=False):
                  """Fill ``m.weight`` from a (optionally truncated) normal; zero ``m.bias``.

                  The bias is zeroed in both modes; previously it was only reset
                  when ``truncated`` was false.
                  """
                  if truncated:
                      # Not a perfect truncated normal: standard-normal samples are
                      # folded into (-2, 2) with fmod, then scaled and shifted.
                      m.weight.data.normal_().fmod_(2).mul_(stddev).add_(mean)
                  else:
                      m.weight.data.normal_(mean, stddev)

                  bias = getattr(m, 'bias', None)
                  if bias is not None:
                      bias.data.zero_()

              normal_init(self.RCNN_rpn.RPN_Conv, 0, 0.01, cfg.TRAIN.TRUNCATED)
              normal_init(self.RCNN_rpn.RPN_cls_score, 0, 0.01, cfg.TRAIN.TRUNCATED)
              normal_init(self.RCNN_rpn.RPN_bbox_pred, 0, 0.01, cfg.TRAIN.TRUNCATED)
              normal_init(self.RCNN_cls_score, 0, 0.01, cfg.TRAIN.TRUNCATED)
              # The box-regression head uses a 10x smaller standard deviation.
              normal_init(self.RCNN_bbox_pred, 0, 0.001, cfg.TRAIN.TRUNCATED)

          def create_architecture(self):
              """Build the network modules, then initialise their weights."""
              self._init_modules()
              self._init_weights()

    ref:https://blog.csdn.net/weixin_43872578/article/details/87930953

  • 相关阅读:
    SQL的四种连接(内连接,外连接)
    MySQL连表操作之一对多
    [转]Mysql连表之多对多
    Hibernate笔记二
    Hibernate框架报错:org.hibernate.PropertyAccessException: IllegalArgumentException occurred while calling setter of com.mikey.hibernate.domain.Person.pid
    Hibernate框架:org.hibernate.exception.SQLGrammarException: Cannot open connection at org.hibernate.exception.SQLStateConverter.convert(SQLStateConverter.java92)
    [转]网络编程三要素
    Hibernate笔记一
    JavaScript高级特征之面向对象笔记
    Myeclipse创建HTML文件中文显示乱码问题
  • 原文地址:https://www.cnblogs.com/wind-chaser/p/11360073.html
Copyright © 2020-2023  润新知