• MMDetection源码解析:Faster RCNN(5)--TwoStageDetector类


    TwoStageDetector类定义在mmdet/models/detectors/two_stage.py中:

    import torch
    import torch.nn as nn
    
    # from mmdet.core import bbox2result, bbox2roi, build_assigner, build_sampler
    from ..builder import DETECTORS, build_backbone, build_head, build_neck
    from .base import BaseDetector
    
    
    @DETECTORS.register_module()
    class TwoStageDetector(BaseDetector):
        """Base class for two-stage detectors.

        Two-stage detectors typically consist of a region proposal network
        (RPN) and a task-specific RoI head that refines the proposals.
        """

        def __init__(self,
                     backbone,
                     neck=None,
                     rpn_head=None,
                     roi_head=None,
                     train_cfg=None,
                     test_cfg=None,
                     pretrained=None):
            """Build backbone, optional neck, optional RPN head and RoI head.

            Args:
                backbone (dict): Config of the backbone.
                neck (dict, optional): Config of the neck. Defaults to None.
                rpn_head (dict, optional): Config of the RPN head. Its
                    ``train_cfg`` / ``test_cfg`` are filled from
                    ``train_cfg.rpn`` / ``test_cfg.rpn``. Defaults to None.
                roi_head (dict, optional): Config of the RoI head. Its
                    ``train_cfg`` / ``test_cfg`` are filled from
                    ``train_cfg.rcnn`` / ``test_cfg.rcnn``. Defaults to None.
                train_cfg (dict, optional): Training config holding ``rpn``
                    and ``rcnn`` sub-configs. Defaults to None.
                test_cfg (dict, optional): Testing config holding ``rpn``
                    and ``rcnn`` sub-configs. Defaults to None.
                pretrained (str, optional): Path to pre-trained weights,
                    forwarded to ``init_weights``. Defaults to None.
            """
            super(TwoStageDetector, self).__init__()
            self.backbone = build_backbone(backbone)

            if neck is not None:
                self.neck = build_neck(neck)

            if rpn_head is not None:
                # Both cfgs are optional; only dereference them when present.
                rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
                rpn_test_cfg = test_cfg.rpn if test_cfg is not None else None
                # Work on a copy so the caller's config is not modified.
                rpn_head_ = rpn_head.copy()
                rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=rpn_test_cfg)
                self.rpn_head = build_head(rpn_head_)

            if roi_head is not None:
                # update train and test cfg here for now
                # TODO: refactor assigner & sampler
                rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None
                rcnn_test_cfg = test_cfg.rcnn if test_cfg is not None else None
                # Copy (as done for rpn_head) instead of mutating the caller's
                # config dict in place.
                roi_head_ = roi_head.copy()
                roi_head_.update(train_cfg=rcnn_train_cfg, test_cfg=rcnn_test_cfg)
                self.roi_head = build_head(roi_head_)

            self.train_cfg = train_cfg
            self.test_cfg = test_cfg

            self.init_weights(pretrained=pretrained)

        @property
        def with_rpn(self):
            """bool: whether the detector has RPN"""
            return hasattr(self, 'rpn_head') and self.rpn_head is not None

        @property
        def with_roi_head(self):
            """bool: whether the detector has a RoI head"""
            return hasattr(self, 'roi_head') and self.roi_head is not None

        def init_weights(self, pretrained=None):
            """Initialize the weights in detector.

            The backbone receives ``pretrained``; the neck (each sub-module if
            it is an ``nn.Sequential``) and the RPN head are initialized
            without it; the RoI head receives ``pretrained`` as well.

            Args:
                pretrained (str, optional): Path to pre-trained weights.
                    Defaults to None.
            """
            super(TwoStageDetector, self).init_weights(pretrained)
            self.backbone.init_weights(pretrained=pretrained)
            if self.with_neck:
                if isinstance(self.neck, nn.Sequential):
                    for m in self.neck:
                        m.init_weights()
                else:
                    self.neck.init_weights()
            if self.with_rpn:
                self.rpn_head.init_weights()
            if self.with_roi_head:
                self.roi_head.init_weights(pretrained)

        def extract_feat(self, img):
            """Directly extract features from the backbone+neck.

            The neck is applied only when the detector has one.
            """
            x = self.backbone(img)
            if self.with_neck:
                x = self.neck(x)
            return x

        def forward_dummy(self, img):
            """Used for computing network flops.

            Runs backbone/neck, the RPN head (if present) and the RoI head on
            1000 random boxes of shape (1000, 4) instead of real proposals.

            See `mmdetection/tools/get_flops.py`

            Returns:
                tuple: RPN outputs (if the detector has an RPN) followed by
                    the RoI head's dummy outputs.
            """
            outs = ()
            # backbone
            x = self.extract_feat(img)
            # rpn
            if self.with_rpn:
                rpn_outs = self.rpn_head(x)
                outs = outs + (rpn_outs, )
            # Random boxes stand in for proposals; only the cost matters here.
            proposals = torch.randn(1000, 4).to(img.device)
            # roi_head
            roi_outs = self.roi_head.forward_dummy(x, proposals)
            outs = outs + (roi_outs, )
            return outs

        def forward_train(self,
                          img,
                          img_metas,
                          gt_bboxes,
                          gt_labels,
                          gt_bboxes_ignore=None,
                          gt_masks=None,
                          proposals=None,
                          **kwargs):
            """Forward pass in training mode; returns the combined losses.

            Args:
                img (Tensor): of shape (N, C, H, W) encoding input images.
                    Typically these should be mean centered and std scaled.

                img_metas (list[dict]): list of image info dict where each dict
                    has: 'img_shape', 'scale_factor', 'flip', and may also contain
                    'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                    For details on the values of these keys see
                    `mmdet/datasets/pipelines/formatting.py:Collect`.

                gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                    shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.

                gt_labels (list[Tensor]): class indices corresponding to each box

                gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                    boxes can be ignored when computing the loss.

                gt_masks (None | Tensor) : true segmentation masks for each box
                    used if the architecture supports a segmentation task.

                proposals : override rpn proposals with custom proposals. Use when
                    `with_rpn` is False.

            Returns:
                dict[str, Tensor]: a dictionary of loss components (RPN losses,
                    if any, merged with the RoI head losses).
            """
            x = self.extract_feat(img)

            losses = dict()

            # RPN forward and loss
            if self.with_rpn:
                # Proposal generation during training uses train_cfg's
                # 'rpn_proposal' if given, otherwise falls back to test_cfg.rpn.
                proposal_cfg = self.train_cfg.get('rpn_proposal',
                                                  self.test_cfg.rpn)
                # gt_labels is deliberately not passed to the RPN (presumably
                # because the RPN is class-agnostic).
                rpn_losses, proposal_list = self.rpn_head.forward_train(
                    x,
                    img_metas,
                    gt_bboxes,
                    gt_labels=None,
                    gt_bboxes_ignore=gt_bboxes_ignore,
                    proposal_cfg=proposal_cfg)
                losses.update(rpn_losses)
            else:
                proposal_list = proposals

            roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list,
                                                     gt_bboxes, gt_labels,
                                                     gt_bboxes_ignore, gt_masks,
                                                     **kwargs)
            losses.update(roi_losses)

            return losses

        async def async_simple_test(self,
                                    img,
                                    img_meta,
                                    proposals=None,
                                    rescale=False):
            """Async test without augmentation.

            Proposals come from the RPN head unless ``proposals`` is given.
            """
            assert self.with_bbox, 'Bbox head must be implemented.'
            x = self.extract_feat(img)

            if proposals is None:
                proposal_list = await self.rpn_head.async_simple_test_rpn(
                    x, img_meta)
            else:
                proposal_list = proposals

            return await self.roi_head.async_simple_test(
                x, proposal_list, img_meta, rescale=rescale)

        def simple_test(self, img, img_metas, proposals=None, rescale=False):
            """Test without augmentation.

            Proposals come from the RPN head unless ``proposals`` is given;
            the RoI head's ``simple_test`` result is returned as-is.
            """
            assert self.with_bbox, 'Bbox head must be implemented.'

            x = self.extract_feat(img)

            if proposals is None:
                proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
            else:
                proposal_list = proposals

            return self.roi_head.simple_test(
                x, proposal_list, img_metas, rescale=rescale)

        def aug_test(self, imgs, img_metas, rescale=False):
            """Test with augmentations.

            If rescale is False, then returned bboxes and masks will fit the scale
            of imgs[0].
            """
            x = self.extract_feats(imgs)
            proposal_list = self.rpn_head.aug_test_rpn(x, img_metas)
            return self.roi_head.aug_test(
                x, proposal_list, img_metas, rescale=rescale)

    TwoStageDetector继承自BaseDetector类,主要有以下函数:

    (1) __init__():初始化函数,主要是对backbone,neck,rpn_head,roi_head等进行设置;

    (2) init_weights():初始化参数值,对backbone、neck、rpn_head、roi_head的参数进行初始化;

    (3) extract_feat():把输入的图像数据送入backbone(若存在neck,再送入neck),得到输出的特征图;

    (4) forward_train():训练时的前向计算,主要参数有img、img_metas、gt_bboxes、gt_labels,返回由各项损失组成的字典losses。提取特征后,首先执行:

                rpn_losses, proposal_list = self.rpn_head.forward_train(
                    x,
                    img_metas,
                    gt_bboxes,
                    gt_labels=None,
                    gt_bboxes_ignore=gt_bboxes_ignore,
                    proposal_cfg=proposal_cfg)

    当with_rpn为真时,通过调用rpn_head的forward_train()函数计算RPN的损失值,并得到候选框列表proposal_list(否则直接使用传入的proposals作为proposal_list)。然后执行:

            roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list,
                                                     gt_bboxes, gt_labels,
                                                     gt_bboxes_ignore, gt_masks,
                                                     **kwargs)

    通过调用roi_head的forward_train()函数计算RoI head的损失值,并与RPN的损失一起合并到losses字典中返回。

  • 相关阅读:
    java常用问题排查工具
    一次CMS GC问题排查过程(理解原理+读懂GC日志)
    nginx [alert] 12339#0: 1024 worker_connections are not enough
    netstat Recv-Q和Send-Q
    Use of Recv-Q and Send-Q
    LoadRunner 11 error:Cannot initialize driver dll
    perf + Flame Graph火焰图分析程序性能
    nginx 499状态码
    supervisor管理nginx
    supervisor管理php-fpm
  • 原文地址:https://www.cnblogs.com/mstk/p/14944917.html
Copyright © 2020-2023  润新知