• RFS


    RFS

    RFS的策略,对类别c,首先有一个统计量(f_{c})。它的含义是,统计出那些起码包含一个类别c实例的图片所占的所有图片的频率。然后通过公式计算(r_{c} = max(1, sqrt(t / f_{c})))。在这里,t是一个超参数。一般来讲,t = 0.001。

    计算出(f_c)之后,我们需要做的事情是,对每一张图片i而言,(r_i = max_{i in c}r_c)。这样的话,在每一个epoch里面,SGD data sampler都是造一个random的permutation。每一张图片重复的次数都是(r_i)次。

    关于RFS的原理是没有一个准则的,它仅仅是一个启发式的算法(heuristic)。当一个instance的(f_{c})减小(lambda)倍,那么它的重复次数就会被增大(sqrt(1 / lambda))倍。

    下面是它的代码解析,其中包括了一些distributed的知识。

    import itertools
    import math
    from collections import defaultdict
    from typing import Optional
    import torch
    from torch.utils.data.sampler import Sampler
    
    from detectron2.utils import comm
    
    class TrainingSampler(Sampler):
        """
        在训练的时候,我们仅仅关心训练数据的"infinite stream",这个组件实现的就是一个"infinite"的下标流。
        这需要所有的wokers进行合作,才能正确的洗牌所有的下标,从而进一步实现采样不同的下标。
        每一个worker里面的sampler都能够很快的产生'indices[worker_id::num_workers]'。这个地方的indices就是一个无限
        流,包含的是'shuffle(range(size)) + shuffle(range(size) + ...' 如果shuffle参数设置为True
        否则就是'range(size) + range(size) + ...' 如果shuffle被设置为False.
        """
    	def __init__(self, size: int, shuffle: bool = True, seed: Optional[int] = None):
            """值得注意的是,这里的seed指的是shuffle的初始种子,在所有的worker之间必须相同
            	如果这个参数被设置未None的话,那么将会random一个数值出来,并且在所有的worker之间进行共享。
            """
    		self._size = size
            assert size > 0
            self._shuffle = shuffle
            if seed is None:
                seed = comm.shared_random_seed()
            self._seed = int(seed)
    
            self._rank = comm.get_rank()
            self._world_size = comm.get_world_size()
        
    	def __iter__(self): # 返回值是生成器
            start = self._rank
            """
            	这里介绍一下itertools.islice。islice(iterable, [start,] stop [, step])
            	输入是一个迭代器或者生成器。注意这个函数会消耗迭代器(毕竟迭代器只能用一次)
            	返回结果是一个迭代器,它可以产生所需要的切片元素。丢弃所有启始索引之前的元素,知道到达所有结束索引为止
            	注意迭代器里面的元素只能够访问一次。
            	None代表无上限。
            	yield from 指的就是返回另一个生成器。后面只要是一个可迭代的就可以了。
            	yield from iterable <=> for i in iterable: yield i
            """
            yield from itertools.islice(self._infinite_indices(), start, None, self._world_size) # 分到的indices并不连续
    
        def _infinite_indices(self): #本函数返回的是一个生成器
            g = torch.Generator()
            g.manual_seed(self._seed)
            while True:
                if self._shuffle:
                    yield from torch.randperm(self._size, generator=g).tolist()
                else:
                    yield from torch.arange(self._size).tolist()
    
     class RepeatFactorTrainingSampler(Sampler):
        """
        某一个样本会根据它的repeat_factors,有可能会出现多次。
        """
    
        def __init__(self, repeat_factors, *, shuffle=True, seed=None):
            """
            Args:
                repeat_factors (Tensor): a float vector, the repeat factor for each indice. When it's
                    full of ones, it is equivalent to ``TrainingSampler(len(repeat_factors), ...)``.
                shuffle (bool): whether to shuffle the indices or not
                seed (int): the initial seed of the shuffle. Must be the same
                    across all workers. If None, will use a random seed shared
                    among workers (require synchronization among all workers).
            """
            self._shuffle = shuffle
            if seed is None:
                seed = comm.shared_random_seed()
            self._seed = int(seed)
    
            self._rank = comm.get_rank()
            self._world_size = comm.get_world_size()
    
            # Split into whole number (_int_part) and fractional (_frac_part) parts. 分成整数部分与小数部分。
            self._int_part = torch.trunc(repeat_factors)
            self._frac_part = repeat_factors - self._int_part
    
        @staticmethod
        def repeat_factors_from_category_frequency(dataset_dicts, repeat_thresh):
            """
            看上面我的解释
            Args:
                dataset_dicts (list[dict]): d2的注释格式
                repeat_thresh (float): 在threshhold之下的类别是应该重复的。
    
            Returns:
                每一张图片的重复次数。
            """
            # 对每一个类来讲,要算出它的frequency
            category_freq = defaultdict(int)
            for dataset_dict in dataset_dicts:  #对每一张图片
                cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]}
                for cat_id in cat_ids:
                    category_freq[cat_id] += 1
            num_images = len(dataset_dicts)
            for k, v in category_freq.items():
                category_freq[k] = v / num_images
    
            # 2. For each category c, compute the category-level repeat factor:
            #    r(c) = max(1, sqrt(t / f(c)))
            category_rep = {
                cat_id: max(1.0, math.sqrt(repeat_thresh / cat_freq))
                for cat_id, cat_freq in category_freq.items()
            }
    
            # 3. 返回的一个tensor,里面每一个值对应的是每一个index对应的repeat次数。
            #    r(I) = max_{c in I} r(c)
            rep_factors = []
            for dataset_dict in dataset_dicts:
                cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]}
                rep_factor = max({category_rep[cat_id] for cat_id in cat_ids}, default=1.0)
                rep_factors.append(rep_factor)
    
            return torch.tensor(rep_factors, dtype=torch.float32)
            
    	def _get_epoch_indices(self, generator):
            """
            Create a list of dataset indices (with repeats) to use for one epoch.
    
            Args:
                generator (torch.Generator): 随机数生成器,用来概率取整。
    
            Returns:
                返回值是indices,tensor类型。每一张图片的indices是经过重复的。
            """
            # 这里说明一件事情就是,由于reapeat factor是一个小数,所以我们
            # 只能够采用一个带有概率的上取整或者下取整。这样能在期望意义上,
            # 等价的。
            rands = torch.rand(len(self._frac_part), generator=generator)
            rep_factors = self._int_part + (rands < self._frac_part).float()
            # Construct a list of indices in which we repeat images as specified
            indices = []
            for dataset_index, rep_factor in enumerate(rep_factors):
                indices.extend([dataset_index] * int(rep_factor.item()))
            return torch.tensor(indices, dtype=torch.int64)
    
        def __iter__(self):
            start = self._rank
            yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)
    
        def _infinite_indices(self):
            g = torch.Generator()
            g.manual_seed(self._seed)
            while True:
                # Sample indices with repeats determined by stochastic rounding; each
                # "epoch" may have a slightly different size due to the rounding.
                indices = self._get_epoch_indices(g)
                if self._shuffle:
                    randperm = torch.randperm(len(indices), generator=g)
                    yield from indices[randperm].tolist()
                else:
                    yield from indices.tolist()
                   
    class InferenceSampler(Sampler):
        """
        Produce indices for inference across all workers.
        Inference needs to run on the __exact__ set of samples,
        therefore when the total number of samples is not divisible by the number of workers,
        this sampler produces different number of samples on different workers.
        """
    
        def __init__(self, size: int):
            """
            Args:
                size (int): the total number of data of the underlying dataset to sample from
            """
            self._size = size
            assert size > 0
            self._rank = comm.get_rank()
            self._world_size = comm.get_world_size()
    
            shard_size = (self._size - 1) // self._world_size + 1
            begin = shard_size * self._rank
            end = min(shard_size * (self._rank + 1), self._size)
            self._local_indices = range(begin, end)
    
    
        def __iter__(self):
            yield from self._local_indices
    
        def __len__(self):
            return len(self._local_indices)
    
  • 相关阅读:
    数据结构——单链表(singly linked list)
    Java——判断回文
    C——swap
    Java动态数组
    mui框架下监听返回按钮
    Ubuntu 18.04版本下安装网易云音乐
    Linux安装Broadcom无线驱动
    EFI环境下的Ubuntu&Win10双系统安装
    Leaflet中添加的不同图层样式图标
    数据插入数据库时,提示表名不存在
  • 原文地址:https://www.cnblogs.com/JohnRan/p/15098402.html
Copyright © 2020-2023  润新知