• pytorch BatchNorm参数详解,计算过程


    https://blog.csdn.net/weixin_39228381/article/details/107896863

    目录

     

    说明

    BatchNorm1d参数

    num_features

    eps

    momentum

    affine

    track_running_stats

    BatchNorm1d训练时前向传播

    BatchNorm1d评估时前向传播

    总结


    说明

    网络训练时和网络评估时,BatchNorm模块的计算方式不同。如果一个网络里包含了BatchNorm,则在训练时需要先调用train(),使网络里的BatchNorm模块的training=True(默认是True),在网络评估时,需要先调用eval()使网络的training=False。

    BatchNorm1d参数

    torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

    num_features

    输入维度是(N, C, L)时,num_features应该取C;这里N是batch size,C是数据的channel,L是数据长度。

    输入维度是(N, L)时,num_features应该取L;这里N是batch size,L是数据长度,这时可以认为每条数据只有一个channel,省略了C

    eps

    对输入数据进行归一化时加在分母上,防止除零,详情见下文。

    momentum

    更新全局均值running_mean和方差running_var时使用该值进行平滑,详情见下文。

    affine

    设为True时,BatchNorm层才会学习参数gammaeta,否则不包含这两个变量,变量名是weight和bias,详情见下文。

    track_running_stats

    设为True时,BatchNorm层会统计全局均值running_mean和方差running_var,详情见下文。

    BatchNorm1d训练时前向传播

    1. 首先对输入batch求E[x]Var[x],并用这两个结果把batch归一化,使其均值为0,方差为1。归一化公式用到了eps(epsilon),即y=frac{x-E[x]}{sqrt{Var[x]+epsilon }}。如下输入内容,shape是(3, 4),即batch_size=3,此时num_features需要传入4。
      1.  
        tensor = torch.FloatTensor([[1, 2, 4, 1],
      2.  
        [6, 3, 2, 4],
      3.  
        [2, 4, 6, 1]])
      此时E[x]=[3, 3, 4, 2]Var[y]_{unbiased}=[7, 1, 4, 3](无偏样本方差)和Var[y]_{biased}=[4.6667, 0.6667, 2.6667, 2.0000](有偏样本方差),有偏和无偏的区别在于无偏的分母是N-1,有偏的分母是N。注意在BatchNorm中,用于更新running_var时,使用无偏样本方差即,但是在对batch进行归一化时,使用有偏样本方差,因此如果batch_size=1,会报错。归一化后的内容如下。
      1.  
        [[-0.9258, -1.2247, 0.0000, -0.7071],
      2.  
        [ 1.3887, 0.0000, -1.2247, 1.4142],
      3.  
        [-0.4629, 1.2247, 1.2247, -0.7071]]
    2. 如果track_running_stats==True,则使用momentum更新模块内部的running_mean(初值是[0., 0., 0., 0.])和running_var(初值是[1., 1., 1., 1.]),更新公式是x_{new}=(1-momentum)	imes x_{cur}+momentum	imes x_{batch},其中x_{new}代表更新后的running_mean和running_var,x_{cur}表示更新前的running_mean和running_var,x_{batch}表示当前batch的均值和无偏样本方差。
    3. 如果track_running_stats==False,则BatchNorm中不含有running_mean和running_var两个变量。
    4. 如果affine==True,则对归一化后的batch进行仿射变换,即乘以模块内部的weight(初值是[1., 1., 1., 1.])然后加上模块内部的bias(初值是[0., 0., 0., 0.]),这两个变量会在反向传播时得到更新。
    5. 如果affine==False,则BatchNorm中不含有weight和bias两个变量,什么都都不做。

    BatchNorm1d评估时前向传播

    1. 如果track_running_stats==True,则对batch进行归一化,公式为y=frac{x-hat{E}[x]}{sqrt{hat{Var}[x]+epsilon }},注意这里的均值和方差是running_mean和running_var,在网络训练时统计出来的全局均值和无偏样本方差。
    2. 如果track_running_stats==False,则对batch进行归一化,公式为y=frac{x-{E}[x]}{sqrt{​{Var}[x]+epsilon }},注意这里的均值和方差是batch自己的mean和var,此时BatchNorm里不含有running_mean和running_var。注意此时使用的是无偏样本方差(和训练时不同),因此如果batch_size=1,会使分母为0,就报错了。
    3. 如果affine==True,则对归一化后的batch进行放射变换,即乘以模块内部的weight然后加上模块内部的bias,这两个变量都是网络训练时学习到的。
    4. 如果affine==False,则BatchNorm中不含有weight和bias两个变量,什么都不做。

    总结

    在使用batchNorm时,通常只需要指定num_features就可以了。网络训练前调用train(),训练时BatchNorm模块会统计全局running_mean和running_var,学习weight和bias,即文献中的gammaeta。网络评估前调用eval(),评估时,对传入的batch,使用统计的全局running_mean和running_var对batch进行归一化,然后使用学习到的weight和bias进行仿射变换。

  • 相关阅读:
    对于开发WEB方面项目需要的工具和技术了解
    SQLServer创建链接服务器
    Tomcat部署Web应用方法总结
    JDK/bin目录下的不同exe文件的用途
    js高级技巧自定义事件
    HTML5 web SQL
    js高级技巧拖放
    图片替换文字
    CSS内容生成(重要内容:css计数器)
    CSS 使元素垂直居中
  • 原文地址:https://www.cnblogs.com/shuimuqingyang/p/14007260.html
Copyright © 2020-2023  润新知