• TensorFlow学习笔记之批归一化:tf.layers.batch_normalization()函数


    关于归一化的讲解的博客——【深度学习】Batch Normalization(批归一化)

    tensorflow中的函数解析在这个博客——tf.nn.batch_normalization()函数解析

    0 前言

    关于批归一化的讲解我们在博客【深度学习】Batch Normalization(批归一化)中做了详细的讲解,不懂的同学可以看一下,下面我们来说一种TensorFlow中的批归一化的代码实现,主要使用的函数是tf.layers.batch_normalization()函数。

    简单来说公式如下:

    y=γ(xμ)/σ+βy=γ(x-μ)/σ+β

    其中xx是输入,yy是输出,μμ是均值,σσ是方差,γγββ是缩放(scale)、偏移(offset)系数。

    1 函数

    tf.layers.batch_normalization(
        inputs,
        axis=-1,
        momentum=0.99,
        epsilon=0.001,
        center=True,
        scale=True,
        beta_initializer=tf.zeros_initializer(),
        gamma_initializer=tf.ones_initializer(),
        moving_mean_initializer=tf.zeros_initializer(),
        moving_variance_initializer=tf.ones_initializer(),
        beta_regularizer=None,
        gamma_regularizer=None,
        beta_constraint=None,
        gamma_constraint=None,
        training=False,
        trainable=True,
        name=None,
        reuse=None,
        renorm=False,
        renorm_clipping=None,
        renorm_momentum=0.99,
        fused=None,
        virtual_batch_size=None,
        adjustment=None
    )
    

    注意:训练时,需要更新滑动平均值(moving_mean)和滑动方差(moving_variance)。默认情况下,更新操作放置在tf.GraphKeys.UPDATE_OPS中,因此需要将它们作为对train_ops的依赖项添加。此外,在获取update_ops集合之前,请确保添加任何批处理标准化(batch_normalization)操作。否则,update_ops将为空,训练 / 验证将无法正常工作。

    参数:

    • inputs:张量输入。
    • axis:一个int,应该被规范化的轴(通常是特征轴)。例如,在使用data_format=“channels_first”Convolution2D层之后,在BatchNormalization中设置axis=1
    • momentum:滑动平均值的动量。
    • epsilon:小浮点数加上方差以避免被零除。
    • center:如果为True,则将beta的偏移量添加到标准化张量。如果为False,则忽略beta
    • scale:如果为True,则乘以gamma。如果为False,则不使用gamma。当下一层是线性的(例如,nn.relu)时,可以禁用此选项,因为可以由下一层进行缩放。
    • beta_initializerbeta权重的初始值设定项。
    • gamma_initializergamma权重的初始值设定项。
    • moving_mean_initializer:滑动平均值的初始化器。
    • moving_variance_initializer:滑动方差的初始值设定项。
    • beta_regularizer:可选的beta权重正则化器。
    • gamma_regularizergamma权重的可选调节器。
    • beta_constraint:由Optimizer更新后应用于beta权重的可选投影函数(例如,用于实现层权重的规范约束或值约束)。函数必须将未投影的变量作为输入,并且必须返回投影的变量(必须具有相同的形状)。在进行异步分布式训练时,使用约束是不安全的。
    • gamma_constraint:由Optimizer更新后应用于gamma权重的可选投影函数。
    • training:要么是Python布尔值,要么是TensorFlow布尔值标量张量(例如占位符)。是以训练模式(使用当前批的统计数据进行规范化)还是以推理模式(使用滑动统计数据进行规范化)返回输出。注意:请确保正确设置此参数,否则您的训练 / 验证将无法正常工作。
    • trainable:布尔值,如果为True,还将变量添加到图形集合GraphKeys.TRAINABLE_VARIABLES(请参见tf.variable)。
    • name:字符串,层的名称。
    • reuse:布尔值,是否以相同的名称重用前一层的权重。
    • `renorm:是否使用批量再规范化(https://arxiv.org/abs/1702.03275)。 这会在训练期间增加额外的变量。这个参数的任何一个值的推断都是相同的。
    • renorm_clipping:一种字典,可以将关键字“rmax”、“rmin”、“dmax”映射到用于剪裁renorm校正的标量张量。校正(r,d)用作corrected_value = normalized_value * r + d,其中r被剪裁为[rmin,rmax]d被剪裁为[-dmax,dmax]。缺少的rmax、rmin和dmax分别设置为inf、0和inf。
    • renorm_momentum:用renorm更新滑动方式和标准偏差的动量。与动量不同,这会影响训练,既不应太小(会增加噪音),也不应太大(会给出过时的估计)。注意,momentum仍然被用来得到均值和方差来进行推理。
    • fused:如果False或者True,尽可能使用更快的融合实现。如果为False,则使用系统建议的实现。
    • virtual_batch_size:一个int。默认情况下,virtual_batch_sizeNone,这意味着在整个批次中执行批次规范化。当virtual_batch_size不是None时,改为执行“Ghost Batch Normalization”,创建每个单独规范化的虚拟子批(使用共享gamma、beta和滑动统计)。必须在执行期间划分实际批大小。
    • adjustment:仅在训练期间,采用包含输入张量(动态)形状的张量并返回一对(scalebias)以应用于标准化值(γβ之前)的函数。例如,如果axis=-1adjustment = lambda shape: ( tf.random_uniform(shape[-1:], 0.93, 1.07)tf.random_uniform(shape[-1:], -0.1, 0.1))将标准化值向上或向下缩放7%,然后将结果向上滑动0.1(每个功能都有独立的缩放和偏移,但在所有示例中都有共享),最后应用gamma 和/或 beta。如果没有,则不应用调整。如果指定了virtual_batch_size,则无法指定。

    返回:

    • 输出张量。

    更多详细的内容可以去官方API文档查看:https://tensorflow.google.cn/versions/r1.12/api_docs/python/tf/layers/batch_normalization

    2 训练

    训练的时候需要特别注意的有两点:

    (1)输入参数设置training=True

    (2)计算loss时,要添加以下代码(即添加update_ops到最后的train_op中)。

    这样才能计算μσ的滑动平均(测试时会用到)

    net_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = net_optimizer.minimize(loss)
    

    3 验证

    测试的时候需要特别注意的只有一点:输入参数设置training=False

    训练时:

    我们需要逐个神经元逐个样本地来计算,这个batch在某一层输出的均值和标准差,然后再对该层的输出进行标准化,同时还要学习gamma和beta两个参数。这是非常耗时的,显然,我们不能在inference的时候使用这种方法。

    解决方案就是,在训练时使用滑动平均维护population均值和方差:

    running_mean = momentum * running_mean + (1 - momentum) * sample_mean
    running_std = momentum * running_std + (1 - momentum) * sample_std
    

    在训练结束保存模型时,running_meanrunning_vartrained_gammatrained_beta一同被保存下来。

    验证时:

    output = (input - running_mean) / np.sqrt(running_std+eps)
    output = trained_gamma * output + trained_beta
    

    也就是说,在验证时,BN对应的操作不再是公式里提到的那样,计算该batch的各种统计量,而是直接使用在训练时保存下来的population均值和方差,进行一次线性变换。这样效率提升了很多。但是缺点也显而易见,如果训练集和验证集不平衡的时候,验证的效果会一直一直很差,所以这样看来深度学习还是数据的游戏,谁的数据质量高谁就是赢家罢了。

    4 测试

    测试是通过读取我们保存的模型,也就是从.checkpoint文件中读取模型参数从而进行测试。

    在这里插入图片描述
    我常用的保存参数的方式是最多保存几个

    var_list = tf.trainable_variables()
    saver = tf.train.Saver(var_list=var_list, max_to_keep=5)
    

    5 批归一化

    在批归一化中有两个重要的参数γβ,这是可训练参数,而μσ不是,它们是通过滑动平均计算出的,如果按照上面的方法保存模型,在读取模型预测时,会报错找不到μσ

    var_list = tf.trainable_variables()
    g_list = tf.global_variables()
    bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
    bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
    var_list += bn_moving_vars
    

    按照上述写法,即可把μσ保存下来,读取模型预测时也不会报错。

    6 总结

    首先指定一个参数is_training,把is_training这个布尔值在构建计算图时就确定,从而更加清晰地调用batch normalization,确保不会出错。

    is_training = tf.placeholder(tf.bool)
    

    通过feed_dict进行is_training的赋值,再通过is_trainingtraining进行赋值。

    ...
    ...
    [train_batch_image, train_batch_label] = sess.run([train_images, train_labels])
                train_feed_dict = {x: train_batch_image,
                                   y_: train_batch_label,
                                   is_training: True}
    ...
    ...
    [val_batch_image, val_batch_label] = sess.run([val_images, val_labels])
                val_feed_dict = {x: val_batch_image,
                                 y_: val_batch_label,
                                 is_training: False}
    ...
    ...
    

    如果想要更多的资源,欢迎关注 @我是管小亮,文字强迫症MAX~

    回复【福利】即可获取我为你准备的大礼,包括C++,编程四大件,NLP,深度学习等等的资料。

    想看更多文(段)章(子),欢迎关注微信公众号「程序员管小亮」~

    在这里插入图片描述

    参考文章

  • 相关阅读:
    day08作业
    day07作业
    day06作业
    day05作业
    OOAD与UML
    大数据(3):基于sogou.500w.utf8数据Hbase和Spark实践
    大数据(2):基于sogou.500w.utf8数据hive的实践
    大数据(1):基于sogou.500w.utf8数据的MapReduce程序设计
    九大排序算法的Java实现
    数字在排序数组中出现的次数
  • 原文地址:https://www.cnblogs.com/hzcya1995/p/13302755.html
Copyright © 2020-2023  润新知