• mmdet阅读笔记


    mmdet

    后续陆续增加源码注释

    -- mmdetection.configs
    注意: _base_里面的文件都是基础的配置,后面的配置文件调用之后可以修改,以后面的为准
    configs/base/dataset: 基础数据的配置文件
    configs/base/models: 基础模型的配置文件
    configs/base/schedules: 基础超参数的配置文件
    configs/base/default_runtime.py: 基础实时配置文件,包括:模型保存间隔,dist后端配置....etc
    configs/others: 上层配置文件,调用base里面的配置,然后针对不同模型不同情况重新封装,实际调用以这个配置参数为准,基础只是通用配置。
    -- mmdetection.demo
    /demo/all: 主要是前向计算测试文件
    -- mmdetection.mmdet
    /mmdet/apis: 训练和前向计算实例化
    /mmdet/core: anchor和bbox等操作具体实现,并被包裹到registry
    /mmdet/datasets: 数据读取处理函数

    /datasets/pipelines: 数据增强具体实现和Compose
    /datasets/samplers:
    -- distributed_sampler.py: 重写了distributed_sampler类,和torch原版一点没变,仅仅改了名字。
    -- group_sampler.py:

    class GroupSampler(Sampler):
        # samples_per_gpu: 使用的GPU数量
        def __init__(self, dataset, samples_per_gpu=3):
            assert hasattr(dataset, 'flag') # 数据中的变量,用来分配类别,在datasets/cumtom.py定义
            self.dataset = dataset
            self.samples_per_gpu = samples_per_gpu
            self.flag = dataset.flag.astype(np.int64)
            self.group_sizes = np.bincount(self.flag)
            self.num_samples = 0
            for i, size in enumerate(self.group_sizes):
                self.num_samples += int(np.ceil(
                    size / self.samples_per_gpu)) * self.samples_per_gpu # 不是整数取最大值
    
        def __iter__(self):
            indices = []
            for i, size in enumerate(self.group_sizes):
                if size == 0:
                    continue
                indice = np.where(self.flag == i)[0]
                assert len(indice) == size
                np.random.shuffle(indice) # random sample
                num_extra = int(np.ceil(size / self.samples_per_gpu)
                                ) * self.samples_per_gpu - len(indice) # 不能整除的额外数据 数量
                indice = np.concatenate(
                    [indice, np.random.choice(indice, num_extra)]) # 不能整除的额外数据 使用前面数据随机取出的数补充
                indices.append(indice)
            indices = np.concatenate(indices)
            indices = [
                indices[i * self.samples_per_gpu:(i + 1) * self.samples_per_gpu]
                for i in np.random.permutation(
                    range(len(indices) // self.samples_per_gpu)) # 分配到每个GPU
            ]
            indices = np.concatenate(indices)
            indices = indices.astype(np.int64).tolist()
            assert len(indices) == self.num_samples
            return iter(indices)
    

    /torch/utils/data/dataset:

    class ConcatDataset(Dataset):
    
        def __init__(self, datasets):
            self.cumulative_sizes = self.cumsum(self.datasets) # 叠加长度总和[len_1, len_1+len_2, len_1+len_2+len_3]
    
        def __len__(self):
            return self.cumulative_sizes[-1]#总长度
    
        def __getitem__(self, idx):
            # 反向索引 
            if idx < 0:
                if -idx > len(self):
                    raise ValueError("absolute value of index should not exceed dataset length")
                idx = len(self) + idx
            # 二分查找子数据集
            dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
            if dataset_idx == 0:
                sample_idx = idx
            else:
                sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
            return self.datasets[dataset_idx][sample_idx] # 获得 指定子数据集 的 指定位置数据
    
        # 老版本名字已更改,可以更改数据集长度
        def cummulative_sizes(self):
            warnings.warn("cummulative_sizes attribute is renamed to "
                          "cumulative_sizes", DeprecationWarning, stacklevel=2)
            return self.cumulative_sizes
    

    /datasets/builder: 实例化数据相关任务:sample、dataloader、dataset
    /datasets/dataset_wrappers.py: 重写concatDataset、RepeatDataset上面已经详细说明,增加数据类别平衡类(具体没看)
    /datasets/custom.py:

    @DATASETS.register_module()
    class CustomDataset(Dataset):
        
        CLASSES = None #种类名称,可以直接定义(常用直接类内定义),也可以外部传入
        
        # 读取全部标签,格式如下:
        ‘’‘
        {
                    'filename': 'a.jpg',
                    'width': 1280,
                    'height': 720,
                    'ann': 
                        {
                            'bboxes': <np.ndarray> (n, 4),
                            'labels': <np.ndarray> (n, ),
                            'bboxes_ignore': <np.ndarray> (k, 4), (optional field)
                            'labels_ignore': <np.ndarray> (k, 4) (optional field)
                        }
        },
        ’‘’
        def load_annotations(self, ann_file):
            pass
        
        # 暂不确定用途
        def load_proposals(self, proposal_file):
            pass
        
        # 过滤不符合条件数据
        def _filter_imgs(self, min_size=32):
            pass
        
        # 获取单个train数据
        def prepare_train_img(self, idx):
            pass
            
        # 获取单个test数据
        def prepare_test_img(self, idx):
        
        # 获得单个图像标注信息
        def get_ann_info(self, idx):
            pass
        
        # 随机选择数据,会使用_set_group_flag
        def _rand_another(self, idx):
            pass
        
        # 按特定格式给图像分类(原始使用长宽比)
        def _set_group_flag(self):
            pass
    

    整个数据读取流程比较清晰:

    graph TD A_1[准备特定格式label] --> A_2 A_2[读取全部label] --> A_3(过滤不合适label) A_3 --> C{train/test} C -->|train | D[读取图像信息+label信息] C -->|test| E[和train类似] D --> D_1{合适/不合适} D_1 --> |不合适| D_2(随机选取) D_1 --> |合适| D_3(直接选取)

    /mmdet/models: 模型实际实现函数
    /mmdet.ops: 需要快速实现的操作,如:NMS、ROIPooling、ROIAlign....
    /mmdet/utils: 一些辅助操作,环境变量和版本等
    -- mmdetection.tests
    /tests/all: 测试脚本,可以用来查看原理和测试
    -- mmdetection.tools
    /tools/all: 杂七杂八文件,包括:训练+测试(仅是入口,实际操作在apis之内),数据转换、计算MAC、转换模型ONNX.....
    /tools/train.py: 单机单卡
    /tools/dist_train.py: 单机单多卡,使用distribution
    /tools/slurm_train.py: 多机多卡

    大致流程:

    1. 准备数据集,在mmdet/datasets
    2. 准备模型,在mmdet/models, loss函数在models里面实现
    3. 准备特殊函数,在/mmdet/core,一些mmdet没有的操作
    4. 配置参数,在/configs, 基础配置可选,后面的参数必须配置
    5. 训练模型,在/mmdet/tools, 调用评估可在配置里设置
    6. 前向推理,在/demo
    Already open...

    ...

  • 相关阅读:
    IE11和传统asp.net的兼容问题
    时区和夏令时
    GTA项目 三, 使用 bootstrap table展示界面,使得data和UI分离
    GTA项目 二, JSON接口开放跨域访问
    GTA项目 一, 包装外部WebService
    DNS域名解析
    CRM 迁移服务器备忘
    CentOS6.5 安装HAProxy 1.5.20
    Custom IFormatProvider
    数据分区考虑
  • 原文地址:https://www.cnblogs.com/wjy-lulu/p/13268321.html
Copyright © 2020-2023  润新知