『TensorFlow』SSD Source Code Study, Part 8: Network Training


    Fork of the project repository: SSD

    The author wrote the training code in a distributed-training style, which makes it unusually bloated; I have added comments to parts of it. I'm not very familiar with multi-machine distributed training, and it isn't the focus here, so I won't go into it much. Instead, I briefly walk through some of the optimization techniques the author uses during training, including the choice of optimizer.

    1. Moving Average

            # =================================================================== #
            # Configure the moving averages.
            # =================================================================== #
            if FLAGS.moving_average_decay:
                moving_average_variables = slim.get_model_variables()
                variable_averages = tf.train.ExponentialMovingAverage(
                    FLAGS.moving_average_decay, global_step)
            else:
                moving_average_variables, variable_averages = None, None
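
    The snippet above only creates the ExponentialMovingAverage object; later in the script it gets attached to the training op as an update op, roughly like this (a sketch, assuming the update_ops list that the script maintains):

            if FLAGS.moving_average_decay:
                # Refresh the shadow (averaged) copies of the model variables once per step.
                update_ops.append(variable_averages.apply(moving_average_variables))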
    

    2. Learning Rate Decay

            with tf.device(deploy_config.optimizer_device()):
                learning_rate = tf_utils.configure_learning_rate(FLAGS,
                                                                 dataset.num_samples,
                                                                 global_step)
    

    The helper that implements this supports three modes: a constant learning rate and two different decay schedules (default: exponential):

    def configure_learning_rate(flags, num_samples_per_epoch, global_step):
        """Configures the learning rate.
    
        Args:
          num_samples_per_epoch: The number of samples in each epoch of training.
          global_step: The global_step tensor.
        Returns:
          A `Tensor` representing the learning rate.
        """
        decay_steps = int(num_samples_per_epoch / flags.batch_size *
                          flags.num_epochs_per_decay)
    
        if flags.learning_rate_decay_type == 'exponential':
            return tf.train.exponential_decay(flags.learning_rate,
                                              global_step,
                                              decay_steps,
                                              flags.learning_rate_decay_factor,
                                              staircase=True,
                                              name='exponential_decay_learning_rate')
        elif flags.learning_rate_decay_type == 'fixed':
            return tf.constant(flags.learning_rate, name='fixed_learning_rate')
        elif flags.learning_rate_decay_type == 'polynomial':
            return tf.train.polynomial_decay(flags.learning_rate,
                                             global_step,
                                             decay_steps,
                                             flags.end_learning_rate,
                                             power=1.0,
                                             cycle=False,
                                             name='polynomial_decay_learning_rate')
        else:
            # Fail loudly instead of silently returning None for an unknown decay type.
            raise ValueError('learning_rate_decay_type [%s] was not recognized' %
                             flags.learning_rate_decay_type)
    

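    For intuition, the staircase exponential schedule above drops the learning rate by a constant factor once every decay_steps steps. A tiny standalone sketch with made-up numbers (illustrative only, not the project defaults):

    # Toy values, for illustration only.
    num_samples_per_epoch = 16000
    batch_size = 32
    num_epochs_per_decay = 2.0        # decay_steps = 16000 / 32 * 2 = 1000
    learning_rate = 0.001
    decay_factor = 0.94

    decay_steps = int(num_samples_per_epoch / batch_size * num_epochs_per_decay)
    for step in (0, 500, 1000, 2500, 10000):
        # staircase=True means the exponent only increases at multiples of decay_steps.
        lr = learning_rate * decay_factor ** (step // decay_steps)
        print(step, round(lr, 6))
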
    3. Optimizer Selection

    optimizer = tf_utils.configure_optimizer(FLAGS, learning_rate)
    

    The choices are plentiful (default: adam):

    def configure_optimizer(flags, learning_rate):
        """Configures the optimizer used for training.
    
        Args:
          learning_rate: A scalar or `Tensor` learning rate.
        Returns:
          An instance of an optimizer.
        """
        if flags.optimizer == 'adadelta':
            optimizer = tf.train.AdadeltaOptimizer(
                learning_rate,
                rho=flags.adadelta_rho,
                epsilon=flags.opt_epsilon)
        elif flags.optimizer == 'adagrad':
            optimizer = tf.train.AdagradOptimizer(
                learning_rate,
                initial_accumulator_value=flags.adagrad_initial_accumulator_value)
        elif flags.optimizer == 'adam':
            optimizer = tf.train.AdamOptimizer(
                learning_rate,
                beta1=flags.adam_beta1,
                beta2=flags.adam_beta2,
                epsilon=flags.opt_epsilon)
        elif flags.optimizer == 'ftrl':
            optimizer = tf.train.FtrlOptimizer(
                learning_rate,
                learning_rate_power=flags.ftrl_learning_rate_power,
                initial_accumulator_value=flags.ftrl_initial_accumulator_value,
                l1_regularization_strength=flags.ftrl_l1,
                l2_regularization_strength=flags.ftrl_l2)
        elif flags.optimizer == 'momentum':
            optimizer = tf.train.MomentumOptimizer(
                learning_rate,
                momentum=flags.momentum,
                name='Momentum')
        elif flags.optimizer == 'rmsprop':
            optimizer = tf.train.RMSPropOptimizer(
                learning_rate,
                decay=flags.rmsprop_decay,
                momentum=flags.rmsprop_momentum,
                epsilon=flags.opt_epsilon)
        elif flags.optimizer == 'sgd':
            optimizer = tf.train.GradientDescentOptimizer(learning_rate)
        else:
        raise ValueError('Optimizer [%s] was not recognized' % flags.optimizer)
        return optimizer
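
    As a quick sanity check, the helper can be driven with a hand-rolled flags object; the values below are purely illustrative, and the snippet assumes configure_optimizer (and TensorFlow 1.x) from above:

    from collections import namedtuple

    # Hypothetical flag container with just enough attributes for the 'adam' branch.
    FakeFlags = namedtuple('FakeFlags', ['optimizer', 'adam_beta1', 'adam_beta2', 'opt_epsilon'])
    flags = FakeFlags(optimizer='adam', adam_beta1=0.9, adam_beta2=0.999, opt_epsilon=1e-8)

    optimizer = configure_optimizer(flags, learning_rate=0.001)
    print(type(optimizer).__name__)   # AdamOptimizer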
    

    4. Training

    In between there is actually a fair amount of distributed gradient computation that I won't cover in detail. Roughly: gradients are computed on each clone, aggregated, and then applied to update the clone networks, with the resulting optimization op exposed as train_tensor, and so on.
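
    A rough, simplified picture of that multi-clone pattern (a sketch, not the actual model_deploy code; the clone devices and losses below are hypothetical):

    # Compute per-clone gradients, average them per variable, then apply once.
    clones = [('/gpu:0', clone0_loss), ('/gpu:1', clone1_loss)]   # hypothetical (device, loss) pairs

    tower_grads = []
    for device, clone_loss in clones:
        with tf.device(device):
            tower_grads.append(optimizer.compute_gradients(clone_loss))

    averaged = []
    for grads_and_vars in zip(*tower_grads):          # same variable across all clones
        grads = [g for g, _ in grads_and_vars if g is not None]
        var = grads_and_vars[0][1]
        averaged.append((tf.reduce_mean(tf.stack(grads), axis=0), var))

    train_tensor = optimizer.apply_gradients(averaged, global_step=global_step)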

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)
    config = tf.ConfigProto(log_device_placement=False,
                            gpu_options=gpu_options)
    saver = tf.train.Saver(max_to_keep=5,
                           keep_checkpoint_every_n_hours=1.0,
                           write_version=2,
                           pad_step_number=False)
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_dir,
        master='',
        is_chief=True,
        init_fn=tf_utils.get_init_fn(FLAGS),            # restores/assigns variables; see the function below
        summary_op=summary_op,                          # the tf.summary.merge op
        number_of_steps=FLAGS.max_number_of_steps,      # total number of training steps
        log_every_n_steps=FLAGS.log_every_n_steps,      # how often (in steps) to log training info
        save_summaries_secs=FLAGS.save_summaries_secs,  # interval (seconds) between summary saves
        saver=saver,                                    # the tf.train.Saver instance
        save_interval_secs=FLAGS.save_interval_secs,    # interval (seconds) between checkpoint saves
        session_config=config,                          # session configuration (tf.ConfigProto)
        sync_optimizer=None)
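
    Roughly speaking, slim.learning.train wraps train_tensor in a supervised session and keeps evaluating it until the step limit is hit, writing summaries and checkpoints on the configured timers. A simplified conceptual equivalent (not slim's actual implementation, and reusing the objects defined above):

    sv = tf.train.Supervisor(logdir=FLAGS.train_dir,
                             global_step=global_step,
                             saver=saver,
                             init_fn=tf_utils.get_init_fn(FLAGS),
                             save_summaries_secs=FLAGS.save_summaries_secs,
                             save_model_secs=FLAGS.save_interval_secs)
    with sv.managed_session(config=config) as sess:
        while not sv.should_stop():
            # One optimization step; train_tensor evaluates to the total loss.
            loss, step = sess.run([train_tensor, global_step])
            if FLAGS.max_number_of_steps and step >= FLAGS.max_number_of_steps:
                break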
    

    The initialization function it calls is shown below:

    def get_init_fn(flags):
        """Returns a function run by the chief worker to warm-start the training.
        Note that the init_fn is only run when initializing the model during the very
        first global step.
    
        Returns:
          An init function run by the supervisor.
        """
        if flags.checkpoint_path is None:
            return None
        # Warn the user if a checkpoint exists in the train_dir. Then ignore.
        if tf.train.latest_checkpoint(flags.train_dir):
            tf.logging.info(
                'Ignoring --checkpoint_path because a checkpoint already exists in %s'
                % flags.train_dir)
            return None
    
        exclusions = []
        if flags.checkpoint_exclude_scopes:
            exclusions = [scope.strip()
                          for scope in flags.checkpoint_exclude_scopes.split(',')]
    
        # TODO(sguada) variables.filter_variables()
        variables_to_restore = []
        for var in slim.get_model_variables():
            excluded = False
            for exclusion in exclusions:
                if var.op.name.startswith(exclusion):
                    excluded = True
                    break
            if not excluded:
                variables_to_restore.append(var)
        # Change model scope if necessary.
        if flags.checkpoint_model_scope is not None:
            variables_to_restore = \
                {var.op.name.replace(flags.model_name,
                                     flags.checkpoint_model_scope): var
                 for var in variables_to_restore}
    
    
        if tf.gfile.IsDirectory(flags.checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(flags.checkpoint_path)
        else:
            checkpoint_path = flags.checkpoint_path
        tf.logging.info('Fine-tuning from %s. Ignoring missing vars: %s' % (checkpoint_path, flags.ignore_missing_vars))
    
        return slim.assign_from_checkpoint_fn(
            checkpoint_path,
            variables_to_restore,
            ignore_missing_vars=flags.ignore_missing_vars)
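
    To make the exclusion and renaming steps concrete, here is a tiny standalone sketch with made-up scope and variable names (illustrative only, not the project's actual flag values):

    # Hypothetical setup: model built under 'ssd_300_vgg', checkpoint saved under
    # 'vgg_16', and one new SSD block excluded from the restore list.
    model_name = 'ssd_300_vgg'
    checkpoint_model_scope = 'vgg_16'
    exclusions = ['ssd_300_vgg/block8']

    var_names = ['ssd_300_vgg/conv1/conv1_1/weights',
                 'ssd_300_vgg/block8/conv1x1/weights']

    kept = [name for name in var_names
            if not any(name.startswith(e) for e in exclusions)]
    remapped = {name.replace(model_name, checkpoint_model_scope): name for name in kept}
    print(remapped)   # {'vgg_16/conv1/conv1_1/weights': 'ssd_300_vgg/conv1/conv1_1/weights'}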
    

    This wraps up the SSD project walkthrough. The training command is below. Note that by default the number of training steps is unbounded, so the job will keep running until you stop it manually; keep an eye on the training metrics and shut it down once they are good enough.

    DATASET_DIR=./tfrecords
    TRAIN_DIR=./logs/
    CHECKPOINT_PATH=./checkpoints/ssd_300_vgg.ckpt
    python train_ssd_network.py \
        --train_dir=${TRAIN_DIR} \
        --dataset_dir=${DATASET_DIR} \
        --dataset_name=pascalvoc_2012 \
        --dataset_split_name=train \
        --model_name=ssd_300_vgg \
        --checkpoint_path=${CHECKPOINT_PATH} \
        --save_summaries_secs=60 \
        --save_interval_secs=600 \
        --weight_decay=0.0005 \
        --optimizer=adam \
        --learning_rate=0.001 \
        --batch_size=32
    

    For how to use the trained model, see the last part of the article on the 集智 (Jizhi) column.

    Original article: https://www.cnblogs.com/hellcat/p/9360640.html