• 实战 迁移学习 VGG19、ResNet50、InceptionV3 实践 猫狗大战 问题


    实战 迁移学习 VGG19、ResNet50、InceptionV3 实践 猫狗大战 问题

    一、实践流程

    1、数据预处理

    主要是对训练数据进行随机偏移、转动等变换图像处理,这样可以尽可能让训练数据多样化

    另外处理数据方式采用分批无序读取的形式,避免了数据按目录排序训练

    1.  
      #数据准备
    2.  
      def DataGen(self, dir_path, img_row, img_col, batch_size, is_train):
    3.  
      if is_train:
    4.  
      datagen = ImageDataGenerator(rescale=1./255,
    5.  
      zoom_range=0.25, rotation_range=15.,
    6.  
      channel_shift_range=25., width_shift_range=0.02, height_shift_range=0.02,
    7.  
      horizontal_flip=True, fill_mode='constant')
    8.  
      else:
    9.  
      datagen = ImageDataGenerator(rescale=1./255)
    10.  
       
    11.  
      generator = datagen.flow_from_directory(
    12.  
      dir_path, target_size=(img_row, img_col),
    13.  
      batch_size=batch_size,
    14.  
      shuffle=is_train)
    15.  
       
    16.  
      return generator
    2、载入现有模型

    这个部分是核心工作,目的是使用ImageNet训练出的权重来做我们的特征提取器,注意这里后面的分类层去掉

    1.  
      base_model = InceptionV3(weights='imagenet', include_top=False, pooling=None,
    2.  
      input_shape=(img_rows, img_cols, color),
    3.  
      classes=nb_classes)

    然后是冻结这些层,因为是训练好的

    1.  
      for layer in base_model.layers:
    2.  
      layer.trainable = False
    而分类部分,需要我们根据现有需求来新定义的,这里可以根据实际情况自己进行调整,比如这样
    1.  
      x = base_model.output
    2.  
      # 添加自己的全链接分类层
    3.  
      x = GlobalAveragePooling2D()(x)
    4.  
      x = Dense(1024, activation='relu')(x)
    5.  
      predictions = Dense(nb_classes, activation='softmax')(x)
    或者
    1.  
      x = base_model.output
    2.  
      #添加自己的全链接分类层
    3.  
      x = Flatten()(x)
    4.  
      predictions = Dense(nb_classes, activation='softmax')(x)
    3、训练模型

    这里我们用fit_generator函数,它可以避免了一次性加载大量的数据,并且生成器与模型将并行执行以提高效率。比如可以在CPU上进行实时的数据提升,同时在GPU上进行模型训练

    1.  
      history_ft = model.fit_generator(
    2.  
      train_generator,
    3.  
      steps_per_epoch=steps_per_epoch,
    4.  
      epochs=epochs,
    5.  
      validation_data=validation_generator,
    6.  
      validation_steps=validation_steps)

    二、猫狗大战数据集

    训练数据540M,测试数据270M,大家可以去官网下载

    https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition/data

    下载后把数据分成dog和cat两个目录来存放

    三、训练

    训练的时候会自动去下权值,比如vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5,但是如果我们已经下载好了的话,可以改源代码,让他直接读取我们的下载好的权值,比如在resnet50.py中

    1、VGG19

    vgg19的深度有26层,参数达到了549M,原模型最后有3个全连接层做分类器所以我还是加了一个1024的全连接层,训练10轮的情况达到了89%

    2、ResNet50

    ResNet50的深度达到了168层,但是参数只有99M,分类模型我就简单点,一层直接分类,训练10轮的达到了96%的准确率

    3、inception_v3

    InceptionV3的深度159层,参数92M,训练10轮的结果

    这是一层直接分类的结果

    这是加了一个512全连接的,大家可以随意调整测试

    四、完整的代码

    1.  
      # -*- coding: utf-8 -*-
    2.  
      import os
    3.  
      from keras.utils import plot_model
    4.  
      from keras.applications.resnet50 import ResNet50
    5.  
      from keras.applications.vgg19 import VGG19
    6.  
      from keras.applications.inception_v3 import InceptionV3
    7.  
      from keras.layers import Dense,Flatten,GlobalAveragePooling2D
    8.  
      from keras.models import Model,load_model
    9.  
      from keras.optimizers import SGD
    10.  
      from keras.preprocessing.image import ImageDataGenerator
    11.  
      import matplotlib.pyplot as plt
    12.  
       
    13.  
      class PowerTransferMode:
    14.  
      #数据准备
    15.  
      def DataGen(self, dir_path, img_row, img_col, batch_size, is_train):
    16.  
      if is_train:
    17.  
      datagen = ImageDataGenerator(rescale=1./255,
    18.  
      zoom_range=0.25, rotation_range=15.,
    19.  
      channel_shift_range=25., width_shift_range=0.02, height_shift_range=0.02,
    20.  
      horizontal_flip=True, fill_mode='constant')
    21.  
      else:
    22.  
      datagen = ImageDataGenerator(rescale=1./255)
    23.  
       
    24.  
      generator = datagen.flow_from_directory(
    25.  
      dir_path, target_size=(img_row, img_col),
    26.  
      batch_size=batch_size,
    27.  
      #class_mode='binary',
    28.  
      shuffle=is_train)
    29.  
       
    30.  
      return generator
    31.  
       
    32.  
      #ResNet模型
    33.  
      def ResNet50_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=2, img_rows=197, img_cols=197, RGB=True, is_plot_model=False):
    34.  
      color = 3 if RGB else 1
    35.  
      base_model = ResNet50(weights='imagenet', include_top=False, pooling=None, input_shape=(img_rows, img_cols, color),
    36.  
      classes=nb_classes)
    37.  
       
    38.  
      #冻结base_model所有层,这样就可以正确获得bottleneck特征
    39.  
      for layer in base_model.layers:
    40.  
      layer.trainable = False
    41.  
       
    42.  
      x = base_model.output
    43.  
      #添加自己的全链接分类层
    44.  
      x = Flatten()(x)
    45.  
      #x = GlobalAveragePooling2D()(x)
    46.  
      #x = Dense(1024, activation='relu')(x)
    47.  
      predictions = Dense(nb_classes, activation='softmax')(x)
    48.  
       
    49.  
      #训练模型
    50.  
      model = Model(inputs=base_model.input, outputs=predictions)
    51.  
      sgd = SGD(lr=lr, decay=decay, momentum=momentum, nesterov=True)
    52.  
      model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
    53.  
       
    54.  
      #绘制模型
    55.  
      if is_plot_model:
    56.  
      plot_model(model, to_file='resnet50_model.png',show_shapes=True)
    57.  
       
    58.  
      return model
    59.  
       
    60.  
       
    61.  
      #VGG模型
    62.  
      def VGG19_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=2, img_rows=197, img_cols=197, RGB=True, is_plot_model=False):
    63.  
      color = 3 if RGB else 1
    64.  
      base_model = VGG19(weights='imagenet', include_top=False, pooling=None, input_shape=(img_rows, img_cols, color),
    65.  
      classes=nb_classes)
    66.  
       
    67.  
      #冻结base_model所有层,这样就可以正确获得bottleneck特征
    68.  
      for layer in base_model.layers:
    69.  
      layer.trainable = False
    70.  
       
    71.  
      x = base_model.output
    72.  
      #添加自己的全链接分类层
    73.  
      x = GlobalAveragePooling2D()(x)
    74.  
      x = Dense(1024, activation='relu')(x)
    75.  
      predictions = Dense(nb_classes, activation='softmax')(x)
    76.  
       
    77.  
      #训练模型
    78.  
      model = Model(inputs=base_model.input, outputs=predictions)
    79.  
      sgd = SGD(lr=lr, decay=decay, momentum=momentum, nesterov=True)
    80.  
      model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
    81.  
       
    82.  
      # 绘图
    83.  
      if is_plot_model:
    84.  
      plot_model(model, to_file='vgg19_model.png',show_shapes=True)
    85.  
       
    86.  
      return model
    87.  
       
    88.  
      # InceptionV3模型
    89.  
      def InceptionV3_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=2, img_rows=197, img_cols=197, RGB=True,
    90.  
      is_plot_model=False):
    91.  
      color = 3 if RGB else 1
    92.  
      base_model = InceptionV3(weights='imagenet', include_top=False, pooling=None,
    93.  
      input_shape=(img_rows, img_cols, color),
    94.  
      classes=nb_classes)
    95.  
       
    96.  
      # 冻结base_model所有层,这样就可以正确获得bottleneck特征
    97.  
      for layer in base_model.layers:
    98.  
      layer.trainable = False
    99.  
       
    100.  
      x = base_model.output
    101.  
      # 添加自己的全链接分类层
    102.  
      x = GlobalAveragePooling2D()(x)
    103.  
      x = Dense(1024, activation='relu')(x)
    104.  
      predictions = Dense(nb_classes, activation='softmax')(x)
    105.  
       
    106.  
      # 训练模型
    107.  
      model = Model(inputs=base_model.input, outputs=predictions)
    108.  
      sgd = SGD(lr=lr, decay=decay, momentum=momentum, nesterov=True)
    109.  
      model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
    110.  
       
    111.  
      # 绘图
    112.  
      if is_plot_model:
    113.  
      plot_model(model, to_file='inception_v3_model.png', show_shapes=True)
    114.  
       
    115.  
      return model
    116.  
       
    117.  
      #训练模型
    118.  
      def train_model(self, model, epochs, train_generator, steps_per_epoch, validation_generator, validation_steps, model_url, is_load_model=False):
    119.  
      # 载入模型
    120.  
      if is_load_model and os.path.exists(model_url):
    121.  
      model = load_model(model_url)
    122.  
       
    123.  
      history_ft = model.fit_generator(
    124.  
      train_generator,
    125.  
      steps_per_epoch=steps_per_epoch,
    126.  
      epochs=epochs,
    127.  
      validation_data=validation_generator,
    128.  
      validation_steps=validation_steps)
    129.  
      # 模型保存
    130.  
      model.save(model_url,overwrite=True)
    131.  
      return history_ft
    132.  
       
    133.  
      # 画图
    134.  
      def plot_training(self, history):
    135.  
      acc = history.history['acc']
    136.  
      val_acc = history.history['val_acc']
    137.  
      loss = history.history['loss']
    138.  
      val_loss = history.history['val_loss']
    139.  
      epochs = range(len(acc))
    140.  
      plt.plot(epochs, acc, 'b-')
    141.  
      plt.plot(epochs, val_acc, 'r')
    142.  
      plt.title('Training and validation accuracy')
    143.  
      plt.figure()
    144.  
      plt.plot(epochs, loss, 'b-')
    145.  
      plt.plot(epochs, val_loss, 'r-')
    146.  
      plt.title('Training and validation loss')
    147.  
      plt.show()
    148.  
       
    149.  
       
    150.  
      if __name__ == '__main__':
    151.  
      image_size = 197
    152.  
      batch_size = 32
    153.  
       
    154.  
      transfer = PowerTransferMode()
    155.  
       
    156.  
      #得到数据
    157.  
      train_generator = transfer.DataGen('data/cat_dog_Dataset/train', image_size, image_size, batch_size, True)
    158.  
      validation_generator = transfer.DataGen('data/cat_dog_Dataset/test', image_size, image_size, batch_size, False)
    159.  
       
    160.  
      #VGG19
    161.  
      #model = transfer.VGG19_model(nb_classes=2, img_rows=image_size, img_cols=image_size, is_plot_model=False)
    162.  
      #history_ft = transfer.train_model(model, 10, train_generator, 600, validation_generator, 60, 'vgg19_model_weights.h5', is_load_model=False)
    163.  
       
    164.  
      #ResNet50
    165.  
      model = transfer.ResNet50_model(nb_classes=2, img_rows=image_size, img_cols=image_size, is_plot_model=False)
    166.  
      history_ft = transfer.train_model(model, 10, train_generator, 600, validation_generator, 60, 'resnet50_model_weights.h5', is_load_model=False)
    167.  
       
    168.  
      #InceptionV3
    169.  
      #model = transfer.InceptionV3_model(nb_classes=2, img_rows=image_size, img_cols=image_size, is_plot_model=True)
    170.  
      #history_ft = transfer.train_model(model, 10, train_generator, 600, validation_generator, 60, 'inception_v3_model_weights.h5', is_load_model=False)
    171.  
       
    172.  
      # 训练的acc_loss图
    173.  
      transfer.plot_training(history_ft)
  • 相关阅读:
    AppleScript
    iOS 架构之文件结构
    Swift
    ERROR ITMS-90032: "Invalid Image Path
    ios中微信原生登陆的坑,ShareSDK的坑
    ios中OC给js传值的方法
    mac电脑中xcode怎么恢复还原快捷键设置
    ios 中 数组、字典转成json格式上传到后台,遇到的问题
    ios 中长按图片或者二维码,保存图片到手机的方法
    ios 中 Plus屏幕适配的问题,xib创建的cell在 Plus出现被拉大的情况
  • 原文地址:https://www.cnblogs.com/shuimuqingyang/p/11231748.html
Copyright © 2020-2023  润新知