• BN



    BN has a few aspects that deserve attention:

    1. The pros and cons of the train/test inconsistency.
    2. Pitfalls at inference time: the moving average.
    3. Pitfalls during training: batch size and batch distribution.
    4. Pitfalls when fine-tuning: parameterization, data distribution, etc.
    5. Pitfalls in implementation: a multi-purpose BN implementation.
    6. Improvements such as GN, precise BN, and so on.

    BN behaves differently at training time and at test time.

    During training, BN normalizes each batch with that batch's own statistics, while updating running estimates of the mean and variance via an exponential moving average (EMA). At test time it does not compute batch statistics; instead it normalizes with the EMA statistics accumulated during training.
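    A minimal sketch of this train/eval difference, in plain PyTorch (the layer size and data here are made up for illustration):

    import torch

    bn = torch.nn.BatchNorm1d(3, momentum=0.1)
    x = torch.randn(8, 3) * 5 + 2

    bn.train()
    y_train = bn(x)         # normalized with this batch's mean/var
    print(bn.running_mean)  # EMA updated: 0.9 * old + 0.1 * batch_mean

    bn.eval()
    y_eval = bn(x)          # normalized with running_mean/running_var instead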

    Writing the update as μ_EMA ← λ·μ_EMA + (1 − λ)·μ_B, where μ_B is the current batch statistic:

    1. When the decay factor λ is too small, the EMA is dominated by the most recent batches and is not a reasonable approximation of the population statistics.
    2. When λ is too large, the EMA needs many iterations to converge.
    3. When the model is unstable, or the data is unstable, this can cause problems (see the sketch after this list).
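    A small simulation of the two failure modes above (note that PyTorch's momentum argument corresponds to 1 − λ in this notation; the numbers are made up):

    import torch

    # PyTorch convention: running = (1 - momentum) * running + momentum * batch,
    # i.e. momentum = 1 - lambda. Track a fixed population mean of 2.0.
    true_mean = 2.0
    for momentum in (0.01, 0.1, 0.9):
        running = 0.0
        for _ in range(100):
            batch_mean = true_mean + 0.5 * torch.randn(()).item()  # noisy batch estimate
            running = (1 - momentum) * running + momentum * batch_mean
        print(f"momentum={momentum}: running={running:.3f} (true mean = {true_mean})")
    # momentum=0.01 (large lambda): still biased toward the initial 0 after 100 iters.
    # momentum=0.9 (small lambda): reaches 2.0 quickly but stays noisy.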

    Use Precise BatchNorm

    Keep using the EMA, but with a fairly large λ, and with the model weights frozen, run forward passes for many iterations. (The implementation below actually goes a step further: it sets momentum to 1.0, so that each step's running stats equal that batch's stats, and then takes the exact average of the per-batch statistics.)

    I haven't really read the paper "Rethinking 'Batch' in BatchNorm", but I did read the precise-BN code:

    In case some of the functions used in it are unfamiliar, a couple of notes first.

    itertools.islice() slices an iterator, and it consumes the underlying iterator as it yields items.
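    A tiny standard-library demo of both properties (nothing specific to the BN code):

    import itertools

    it = iter(range(10))
    print(list(itertools.islice(it, 3)))  # [0, 1, 2] -- the first 3 items
    print(next(it))                       # 3 -- the slice consumed those items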

    running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1)
    running_var[i] += (bn.running_var - running_var[i]) / (ind + 1)

    This is easy to see: it is the incremental (online) form of the arithmetic mean. If m_k denotes the mean of the first k values, then m_k = m_{k-1} + (x_k − m_{k-1}) / k, which is equivalent to summing everything first and then dividing by the count.
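    A quick numeric check of that equivalence (the values are arbitrary):

    xs = [1.0, 5.0, 3.0, 7.0]

    # Incremental form, as used in the precise-BN code below.
    avg = 0.0
    for ind, x in enumerate(xs):
        avg += (x - avg) / (ind + 1)

    assert abs(avg - sum(xs) / len(xs)) < 1e-12  # same as sum-then-divide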

    #!/usr/bin/env python3
    # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
    
    import itertools
    
    import torch
    
    BN_MODULE_TYPES = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
    )
    
    
    @torch.no_grad()
    def update_bn_stats(model, data_loader, num_iters: int = 200):
        """
        Recompute and update the batch norm stats to make them more precise. During
        training both BN stats and the weight are changing after every iteration, so
        the running average can not precisely reflect the actual stats of the
        current model.
        In this function, the BN stats are recomputed with fixed weights, to make
        the running average more precise. Specifically, it computes the true average
        of per-batch mean/variance instead of the running average.
    
        Args:
            model (nn.Module): the model whose bn stats will be recomputed.
    
                Note that:
    
                1. This function will not alter the training mode of the given model.
                   Users are responsible for setting the layers that need
                   precise-BN to training mode, prior to calling this function.
    
                2. Be careful if your models contain other stateful layers in
                   addition to BN, i.e. layers whose state can change in forward
                   iterations.  This function will alter their state. If you wish
                   them unchanged, you need to either pass in a submodule without
                   those layers, or backup the states.
            data_loader (iterator): an iterator. Produce data as inputs to the model.
            num_iters (int): number of iterations to compute the stats.
        """
        bn_layers = get_bn_modules(model)
    
        if len(bn_layers) == 0:
            return
    
        # In order to make the running stats only reflect the current batch, the
        # momentum is disabled.
        # bn.running_mean = (1 - momentum) * bn.running_mean + momentum * batch_mean
        # Setting the momentum to 1.0 to compute the stats without momentum.
        momentum_actual = [bn.momentum for bn in bn_layers]
        for bn in bn_layers:
            bn.momentum = 1.0
    
        # Note that running_var actually means "running average of variance"
        running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
        running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]
    
        ind = -1  # stays -1 if data_loader yields nothing, so the assert below fails cleanly
        for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)):
            model(inputs)
    
            for i, bn in enumerate(bn_layers):
                # Accumulates the bn stats.
                running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1)
                running_var[i] += (bn.running_var - running_var[i]) / (ind + 1)
                # We compute the "average of variance" across iterations.
        assert ind == num_iters - 1, (
            "update_bn_stats is meant to run for {} iterations, "
            "but the dataloader stops at {} iterations.".format(num_iters, ind)
        )
    
        for i, bn in enumerate(bn_layers):
            # Sets the precise bn stats.
            bn.running_mean = running_mean[i]
            bn.running_var = running_var[i]
            bn.momentum = momentum_actual[i]
    
    
    def get_bn_modules(model):
        """
        Find all BatchNorm (BN) modules that are in training mode. See
        cvpack2.modeling.nn_utils.precise_bn.BN_MODULE_TYPES for a list of all modules that are
        included in this search.
    
        Args:
            model (nn.Module): a model possibly containing BN modules.
    
        Returns:
            list[nn.Module]: all BN modules in the model.
        """
        # Finds all the bn layers.
        bn_layers = [
            m
            for m in model.modules()
            if m.training and isinstance(m, BN_MODULE_TYPES)
        ]
        return bn_layers
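    A hypothetical usage sketch (`model` and `train_loader` are stand-in names, not defined in the snippet above):

    # `model` is a trained network; `train_loader` yields batches the model
    # accepts directly, matching the data_loader contract in the docstring.
    model.train()   # the BN layers to be recalibrated must be in training mode
    update_bn_stats(model, iter(train_loader), num_iters=200)
    model.eval()    # evaluate with the freshly recomputed running stats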
    
    