• MMDetection源码解析:Faster RCNN(6)--SingleRoIExtractor类和BaseRoIExtractor类


    SingleRoIExtractor类定义在mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py中,继承自BaseRoIExtractor类.其作用是根据每个RoI的尺度将其分配到FPN的某一特征层,并在该层上用RoI算子(如RoIAlign)提取固定尺寸的RoI特征.

    import torch
    from mmcv.runner import force_fp32
    
    from mmdet.models.builder import ROI_EXTRACTORS
    from .base_roi_extractor import BaseRoIExtractor
    
    
    @ROI_EXTRACTORS.register_module()
    class SingleRoIExtractor(BaseRoIExtractor):
        """Extract RoI features from a single level feature map.

        If there are multiple input feature levels, each RoI is mapped to a level
        according to its scale. The mapping rule is proposed in
        `FPN <https://arxiv.org/abs/1612.03144>`_.

        Args:
            roi_layer (dict): Specify RoI layer type and arguments.
            out_channels (int): Output channels of RoI layers.
            featmap_strides (List[int]): Strides of input feature maps, one per
                level; one RoI layer is built for each stride.
            finest_scale (int): Scale threshold of mapping to level 0.
                Default: 56.
        """

        def __init__(self,
                     roi_layer,
                     out_channels,
                     featmap_strides,
                     finest_scale=56):
            super(SingleRoIExtractor, self).__init__(roi_layer, out_channels,
                                                     featmap_strides)
            # Base RoI size used by map_roi_levels() to pick a level.
            self.finest_scale = finest_scale

        def map_roi_levels(self, rois, num_levels):
            """Map rois to corresponding feature levels by scales.

            The scale of a RoI is ``sqrt(w * h)``; its level is
            ``floor(log2(scale / finest_scale))`` clamped to
            ``[0, num_levels - 1]``:

            - scale < finest_scale * 2: level 0
            - finest_scale * 2 <= scale < finest_scale * 4: level 1
            - finest_scale * 4 <= scale < finest_scale * 8: level 2
            - scale >= finest_scale * 8: level 3

            Args:
                rois (Tensor): Input RoIs, shape (k, 5). Columns 1-4 are
                    ``(x1, y1, x2, y2)``.
                num_levels (int): Total level number.

            Returns:
                Tensor: Level index (0-based) of each RoI, shape (k, )
            """
            scale = torch.sqrt(
                (rois[:, 3] - rois[:, 1]) * (rois[:, 4] - rois[:, 2]))
            # The small epsilon presumably guards against float round-off
            # pushing a value that is exactly on a power-of-two boundary
            # just below it.
            target_lvls = torch.floor(torch.log2(scale / self.finest_scale + 1e-6))
            target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long()
            return target_lvls

        @force_fp32(apply_to=('feats', ), out_fp16=True)
        def forward(self, feats, rois, roi_scale_factor=None):
            """Extract a fixed-size feature for every RoI.

            Args:
                feats (Sequence[Tensor]): Multi-level feature maps;
                    ``feats[i]`` is pooled by ``self.roi_layers[i]``.
                rois (Tensor): RoIs of shape (k, 5). Columns 1-4 are
                    ``(x1, y1, x2, y2)``; column 0 is passed through unchanged
                    (presumably the batch index).
                roi_scale_factor (float, optional): If given, every RoI is
                    rescaled around its center by this factor before pooling.
                    With several levels, the level assignment still uses the
                    original (unscaled) RoI sizes.

            Returns:
                Tensor: RoI features of shape
                    ``(k, out_channels, *self.roi_layers[0].output_size)``.
            """
            out_size = self.roi_layers[0].output_size
            num_levels = len(feats)
            roi_feats = feats[0].new_zeros(
                rois.size(0), self.out_channels, *out_size)
            # TODO: remove this when parrots supports
            if torch.__version__ == 'parrots':
                roi_feats.requires_grad = True

            if num_levels == 1:
                if len(rois) == 0:
                    return roi_feats
                # Apply roi_scale_factor on the single-level shortcut as well,
                # so the argument is not silently ignored here.
                if roi_scale_factor is not None:
                    rois = self.roi_rescale(rois, roi_scale_factor)
                return self.roi_layers[0](feats[0], rois)

            # Levels are assigned from the unscaled RoIs; rescaling only
            # changes the pooled region, not the chosen level.
            target_lvls = self.map_roi_levels(rois, num_levels)
            if roi_scale_factor is not None:
                rois = self.roi_rescale(rois, roi_scale_factor)
            for i in range(num_levels):
                inds = target_lvls == i
                if inds.any():
                    rois_ = rois[inds, :]
                    roi_feats_t = self.roi_layers[i](feats[i], rois_)
                    roi_feats[inds] = roi_feats_t
                else:
                    # No RoI maps to this level. Add a zero-valued term that
                    # depends on every parameter and on feats[i] so they stay
                    # in the autograd graph. NOTE(review): presumably this
                    # avoids "unused parameter" errors in distributed training.
                    roi_feats += sum(
                        x.view(-1)[0]
                        for x in self.parameters()) * 0. + feats[i].sum() * 0.
            return roi_feats

    主要的函数有:

    (1) __init__():初始化函数,设置finest_scale的值,作为分配样本到FPN哪一层的依据;

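    (2) map_roi_levels():根据RoI的尺度把每个RoI分配到FPN的某一层.尺度取RoI面积的平方根 $s=\sqrt{wh}$,层号为 $\lfloor \log_2(s/\text{finest\_scale} + 10^{-6}) \rfloor$,再截断到 $[0, \text{num\_levels}-1]$.当finest_scale=56且第0层为P2时,这与FPN论文中的公式 $k=\lfloor k_0+\log_2(\sqrt{wh}/224)\rfloor$ ($k_0=4$)等价,层号即 $k-2$;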

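    (3) forward():前向传播,参数有feats、rois和roi_scale_factor.feats即FPN输出的多层特征图.rois是形状为(k, 5)的ROI张量,第1~4列为框坐标(x1, y1, x2, y2),第0列通常为该ROI所属图像在batch中的索引.有多个特征层时,每个ROI先由map_roi_levels()分配到某一层;若给定roi_scale_factor,再以中心为基准对ROI进行缩放;最后用该层对应的ROI算子提取特征.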

    SingleRoIExtractor类的函数较少,很多功能是通过调用BaseRoIExtractor类的方法来实现.BaseRoIExtractor类定义在mmdet/models/roi_heads/roi_extractors/base_roi_extractor.py中:

    from abc import ABCMeta, abstractmethod
    
    import torch
    import torch.nn as nn
    from mmcv import ops
    
    
    class BaseRoIExtractor(nn.Module, metaclass=ABCMeta):
        """Base class for RoI extractor.

        Subclasses must implement :meth:`forward`.

        Args:
            roi_layer (dict): Specify RoI layer type and arguments. ``type``
                must name a class in ``mmcv.ops`` (e.g. ``RoIAlign``); the
                remaining keys are passed to its constructor.
            out_channels (int): Output channels of RoI layers.
            featmap_strides (List[int]): Strides of input feature maps, one per
                level.
        """

        def __init__(self, roi_layer, out_channels, featmap_strides):
            super(BaseRoIExtractor, self).__init__()
            self.roi_layers = self.build_roi_layers(roi_layer, featmap_strides)
            self.out_channels = out_channels
            self.featmap_strides = featmap_strides
            # Mixed-precision flag; presumably consulted by mmcv's fp16
            # decorators (e.g. ``force_fp32``) used by subclasses.
            self.fp16_enabled = False

        @property
        def num_inputs(self):
            """int: Number of input feature maps."""
            return len(self.featmap_strides)

        def init_weights(self):
            """No-op: the RoI layers built here are not initialized by this class."""
            pass

        def build_roi_layers(self, layer_cfg, featmap_strides):
            """Build RoI operator to extract feature from each level feature map.

            Args:
                layer_cfg (dict): Dictionary to construct and config RoI layer
                    operation. Options are modules under ``mmcv/ops`` such as
                    ``RoIAlign``. It is copied, not mutated.
                featmap_strides (List[int]): The strides of input feature maps
                    w.r.t. the original image size. Each is used as
                    ``spatial_scale=1 / stride`` to scale RoI coordinates
                    (original image coordinate system) to that level's
                    feature coordinate system.

            Returns:
                nn.ModuleList: The RoI extractor modules, one per level feature
                    map.

            Raises:
                AssertionError: If ``layer_cfg['type']`` is not found in
                    ``mmcv.ops``.
            """
            cfg = layer_cfg.copy()
            layer_type = cfg.pop('type')
            assert hasattr(ops, layer_type), (
                f'Unsupported RoI layer type "{layer_type}": '
                f'not found in mmcv.ops')
            layer_cls = getattr(ops, layer_type)
            # One RoI layer per level, each bound to that level's stride.
            roi_layers = nn.ModuleList(
                [layer_cls(spatial_scale=1 / s, **cfg) for s in featmap_strides])
            return roi_layers

        def roi_rescale(self, rois, scale_factor):
            """Scale RoI width and height around the RoI center.

            Args:
                rois (torch.Tensor): RoI (Region of Interest), shape (n, 5).
                    Columns 1-4 are ``(x1, y1, x2, y2)``; column 0 is copied
                    through unchanged.
                scale_factor (float): Scale factor that RoI width and height
                    will be multiplied by.

            Returns:
                torch.Tensor: Scaled RoI, shape (n, 5).
            """
            cx = (rois[:, 1] + rois[:, 3]) * 0.5
            cy = (rois[:, 2] + rois[:, 4]) * 0.5
            w = rois[:, 3] - rois[:, 1]
            h = rois[:, 4] - rois[:, 2]
            new_w = w * scale_factor
            new_h = h * scale_factor
            x1 = cx - new_w * 0.5
            x2 = cx + new_w * 0.5
            y1 = cy - new_h * 0.5
            y2 = cy + new_h * 0.5
            new_rois = torch.stack((rois[:, 0], x1, y1, x2, y2), dim=-1)
            return new_rois

        @abstractmethod
        def forward(self, feats, rois, roi_scale_factor=None):
            """Extract RoI features; must be implemented by subclasses."""
            pass

    主要的函数有:

    (1) __init__():初始化函数,有roi_layer, out_channels, featmap_strides等几个参数,通过build_roi_layers()函数构造ROI层;

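    (2) build_roi_layers():构造ROI层,参数有layer_cfg和featmap_strides.layer_cfg中的type指定mmcv.ops中的ROI算子(如RoIAlign);函数为每个stride构造一个该算子,并令spatial_scale=1/stride,用于把原图坐标映射到对应特征图的坐标.最后用nn.ModuleList()把它们组织起来;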

    (3) roi_rescale():以ROI的中心为基准,将其宽和高都乘以scale_factor,第0列保持不变;

    (4) forward():抽象方法,由子类实现.

  • 相关阅读:
    Java8集合框架——集合工具类Collections内部方法浅析
    Java8集合框架——LinkedHashSet源码分析
    Java8集合框架——HashSet源码分析
    Java8集合框架——LinkedHashMap源码分析
    Spring源码分析(001)——环境搭建
    SpringBoot2(007):关于Spring beans、依赖注入 和 @SpringBootApplication 注解
    SpringBoot2(006):关于配置类(Configuration Classes)和自动配置(Auto-configuration)
    SpringBoot2(005):关于工程代码结构的建议
    SpringBoot2(004):关于 Build Systems (构建系统)
    html中的dl,dt,dd标签
  • 原文地址:https://www.cnblogs.com/mstk/p/15030326.html
Copyright © 2020-2023  润新知