• 使用TensorFlow中的Batch Normalization


    问题

    训练神经网络是一个很复杂的过程,在前面提到了深度学习中常用的激活函数,例如ELU或者Relu的变体能够在开始训练的时候很大程度上减少梯度消失或者爆炸问题,但是却不能保证在训练过程中不出现该问题,例如在训练过程中每一层输入数据分布发生了改变了,那么我们就需要使用更小的learning rate去训练,这一现象被称为internal covariate shiftBatch Normalization能够很好的解决这一问题。目前该算法已经被广泛应用在深度学习模型中,该算法的强大至于在于:

    • 可以选择一个较大的学习率,能够达到快速收敛的效果。
    • 能够起到Regularizer的效果,在一些情况下可以不使用Dropout,因为BN提高了模型的泛化能力

    介绍

    我们在将数据输入到神经网络中往往需要对数据进行归一化,原因在于模型的目的就是为了学习模型的数据的分布,如果训练集的数据分布和测试集的不一样那么模型的泛化能力就会很差,另一方面如果模型的每一 batch的数据分布都不一样,那么模型就需要去学习不同的分布,这样模型的训练速度会大大降低。
    BN是一个独立的步骤,被应用在激活函数之前,它简单地对输入进行零中心(zero-center)和归一化(normalize),然后使用两个新参数来缩放和移动结果(一个用于缩放,另一个用于缩放转移)。 换句话说,BN让模型学习最佳的尺度和 每层的输入的平均值。
    为了零中心和归一化数据的分布,BN需要去估算输入的mean和standard deviation,算法的计算过程如下:

    其中:

    • (u_B)是mini-btach (B)的均值,(sigma)是mini-btach的标准差
    • (m_B)是mini-batch中的样本
    • (hat{x}^{(i)}) 是zero-center和normalize后的输入
    • 公式4是一个线性变换,是对数据分布的重构,(z^{(i)})是算法对数据重构的output,(gamma)(eta)分别代表的是对数据的scaleshift,是我们需要学习的参数

    应用

    接下来我们就使用TensorFlow来实现带有BN的神经网络,步骤和前面讲到的很多一样,只是在输入激活函数之前多处理了一部而已,在TF中我们使用的实现是tf.layers.batch_normalization

    import tensorflow as tf
    
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets("./") #自动下载数据到这个目录
    tf.reset_default_graph()
    n_inputs = 28 * 28
    n_hidden1 = 300
    n_hidden2 = 100
    n_outputs = 10
    X = tf.placeholder(tf.float32, shape=(None, n_inputs), name="X")
    y = tf.placeholder(tf.int64, shape=(None), name="y")
    training = tf.placeholder_with_default(False, shape=(), name='training')
    
    hidden1 = tf.layers.dense(X, n_hidden1, name="hidden1")
    bn1 = tf.layers.batch_normalization(hidden1, training=training, momentum=0.9)
    bn1_act = tf.nn.elu(bn1)
    
    hidden2 = tf.layers.dense(bn1_act, n_hidden2, name="hidden2")
    bn2 = tf.layers.batch_normalization(hidden2, training=training, momentum=0.9)
    bn2_act = tf.nn.elu(bn2)
    
    logits_before_bn = tf.layers.dense(bn2_act, n_outputs, name="outputs")
    logits = tf.layers.batch_normalization(logits_before_bn, training=training,
                                          momentum=0.9)
    
    with tf.name_scope("loss"):
        xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y,logits=logits)#labels允许的数据类型有int32, int64
        loss = tf.reduce_mean(xentropy,name="loss")
    learning_rate = 0.01
    with tf.name_scope("train"):
        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
        training_op = optimizer.minimize(loss)
    with tf.name_scope("eval"):
        correct = tf.nn.in_top_k(logits,y,1) #取值最高的一位
        accuracy = tf.reduce_mean(tf.cast(correct,tf.float32)) #结果boolean转为0,1
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    
    extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    
    n_epochs = 20
    batch_size = 200
    with tf.Session() as sess:
        init.run()
        for epoch in range(n_epochs):
            for iteration in range(mnist.train.num_examples // batch_size):
                X_batch, y_batch = mnist.train.next_batch(batch_size)
                sess.run([training_op, extra_update_ops],
                        feed_dict={training: True, X: X_batch, y: y_batch})
            accuracy_val = accuracy.eval(feed_dict={X: mnist.test.images,
                                                    y: mnist.test.labels})
            print(epoch, "Test accuracy:", accuracy_val)
    
    

    在上面代码中有一句需要解释一下

    extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    

    这是因为在计算BN中需要计算moving_meanmoving_variance并且更新,所以在执行run的时候需要将其添加到执行列表中。我们还可以这样写

    with tf.name_scope("train"):
        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
        extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(extra_update_ops):
            training_op = optimizer.minimize(loss)
    

    在训练的时候就只需要更新一个参数

    sess.run(training_op, feed_dict={training: True, X: X_batch, y: y_batch})
    

    此外,我们会发现在编写神经网络代码中,很多代码都是重复的可以将其模块化,例如将构建每一层神经网络的代码封装成一个function,不过这都是后话,看个人喜好吧。

  • 相关阅读:
    Java学习62
    Java学习61
    Maven3种打包方式之一maven-assembly-plugin的使用
    sftp 上传下载 命令介绍
    JMock+Junit4结合完成TDD实例
    UML类图中类与类的四种关系图解
    接口之间的多继承
    Linux中在当前目录下查找某个文件
    .gitignore与exclude
    pro git
  • 原文地址:https://www.cnblogs.com/wxshi/p/8317489.html
Copyright © 2020-2023  润新知