• keras BatchNormalization 之坑


    任务简述:最近做一个图像分类的任务, 一开始拿vgg跑一个baseline,输出看起来很正常:

    随后,我尝试其他的一些经典的模型架构,比如resnet50, xception,但训练输出显示明显异常:

    val_loss 一直乱蹦,val_acc基本不发生变化。

    检查了输入数据没发现问题,因此怀疑是网络构造有问题, 对比了vgg同xception, resnet在使用layer上的异同,认为问题可能出在BN层上,将vgg添加了BN层之后再训练果然翻车。

    翻看keras BN 的源码, 原来keras 的BN层的call函数里面有个默认参数traing, 默认是None。此参数意义如下:

    training=False/0, 训练时通过每个batch的移动平均的均值、方差去做批归一化,测试时拿整个训练集的均值、方差做归一化

    training=True/1/None,训练时通过当前batch的均值、方差去做批归一化,测试时拿整个训练集的均值、方差做归一化

     当training=None时,训练和测试的批归一化方式不一致,导致validation的输出指标翻车。

    当training=True时,拿训练完的模型预测一个样本和预测一个batch的样本的差异非常大,也就是预测的结果根据batch的大小会不同!导致模型结果无法准确评估!也是个坑!

    用keras的BN时切记要设置training=False!!!

    def build_model():
        Inputs = Input(shape=intput_shape, name='input')
        x_tmp = Lambda(lambda c: tf.image.rgb_to_grayscale(c))(Inputs)
        x_tmp = Conv2D(64, (3, 3), activation='relu')(x_tmp)
        x_tmp = Conv2D(64, (3, 3), activation='relu')(x_tmp)
        x_tmp = BatchNormalization(x_tmp, training=False)
        x_tmp = MaxPooling2D(pool_size=(2, 2))(x_tmp)
    
        x_tmp = Flatten()(x_tmp)
        x_tmp = Dense(128, activation='relu')(x_tmp)
        outputs = Dense(10, activation='softmax')(x_tmp)
        model = Model(Inputs, outputs)
        return model

    参考:

    https://arxiv.org/pdf/1502.03167v3.pdf

    https://github.com/keras-team/keras/blob/master/keras/layers/normalization.py#L16

  • 相关阅读:
    spring学习10-AOP
    spring学习9-代理模式
    spring学习6-bean的自动装配
    PyQT5使用心得
    Python 时间戳和日期相互转换
    requests模块的入门使用
    Celery异步任务
    MySQL和python交互
    MySQL高级
    MySQL中select的使用
  • 原文地址:https://www.cnblogs.com/Fosen/p/11419930.html
Copyright © 2020-2023  润新知