• tensorflow slim代码使用


    此处纯粹作为个人学习使用,原文链接:https://www.jianshu.com/p/dc24e54aec81

    这篇文章借鉴了很多博文,是一篇关于slim库的总结。

    导入slim模块

    import tensorflow.contrib.slim as slim

    定义slim的变量

    # Model variables: the model's parameters (e.g. layer weights).
    weights = slim.model_variable('weights', shape=[10, 10, 3, 3],
                                  initializer=tf.truncated_normal_initializer(stddev=0.1),
                                  regularizer=slim.l2_regularizer(0.05),
                                  device='/CPU:0')
    model_variables = slim.get_model_variables()  # List of the model variables created so far.

    # Regular (non-model) variables
    my_var = slim.variable('my_var', shape=[20, 1],
                           initializer=tf.zeros_initializer())
    regular_variables_and_model_variables = slim.get_variables()

    # slim.model_variable registers the variable as a *model* variable (a model
    # parameter), so it shows up in slim.get_model_variables(). slim.variable
    # creates a regular variable, which only shows up in slim.get_variables().
    # NOTE(review): per the TF-Slim docs both kinds are global variables, so a
    # default tf.train.Saver() saves both. The claim that slim.variable creates
    # a local, unsaved variable is incorrect.

    Slim中实现一个层

    inputs = ...
    net = slim.conv2d(inputs, 128, [3, 3], scope='conv1_1')

    # Code reuse: apply the same conv2d layer 3 times in a row.
    net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')
    net = slim.max_pool2d(net, [2, 2], scope='pool2')

    # Stacking layers that differ only in their arguments, written out by hand...
    x = slim.fully_connected(x, 32, scope='fc/fc_1')
    x = slim.fully_connected(x, 64, scope='fc/fc_2')
    x = slim.fully_connected(x, 128, scope='fc/fc_3')
    # ...or, equivalently, with slim.stack (use one form or the other, not both).
    x = slim.stack(x, slim.fully_connected, [32, 64, 128], scope='fc')

    # Plain approach
    x = slim.conv2d(x, 32, [3, 3], scope='core/core_1')
    x = slim.conv2d(x, 32, [1, 1], scope='core/core_2')
    x = slim.conv2d(x, 64, [3, 3], scope='core/core_3')
    x = slim.conv2d(x, 64, [1, 1], scope='core/core_4')

    # Shorthand: each tuple holds the extra positional arguments of one conv2d call.
    x = slim.stack(x, slim.conv2d,
                   [(32, [3, 3]), (32, [1, 1]), (64, [3, 3]), (64, [1, 1])],
                   scope='core')

    使用arg_scope简化:为多个层统一定义相同的参数

    # Every slim.conv2d call inside this arg_scope gets these default arguments
    # unless the call overrides them explicitly.
    with slim.arg_scope([slim.conv2d], padding='SAME',
                        weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
                        weights_regularizer=slim.l2_regularizer(0.0005)):
        net = slim.conv2d(inputs, 64, [11, 11], scope='conv1')
        # Overrides the scoped default padding='SAME'.
        net = slim.conv2d(net, 128, [11, 11], padding='VALID', scope='conv2')
        net = slim.conv2d(net, 256, [11, 11], scope='conv3')


    # Nested arg_scopes: the outer scope applies to both conv2d and
    # fully_connected; the inner scope adds conv2d-only defaults.
    with slim.arg_scope([slim.conv2d, slim.fully_connected],
                        activation_fn=tf.nn.relu,
                        weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
                        weights_regularizer=slim.l2_regularizer(0.0005)):
        with slim.arg_scope([slim.conv2d], stride=1, padding='SAME'):
            # Positional stride 4 and padding='VALID' override the inner defaults.
            net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', scope='conv1')
            # Overrides the outer weights_initializer for this layer only.
            net = slim.conv2d(net, 256, [5, 5],
                              weights_initializer=tf.truncated_normal_initializer(stddev=0.03),
                              scope='conv2')
            # The inner scope lists only conv2d, so only the outer defaults apply here.
            net = slim.fully_connected(net, 1000, activation_fn=None, scope='fc')

    训练模型:定义损失函数(losses)

    loss = slim.losses.softmax_cross_entropy(predictions, labels)
    # Custom losses
    # Define the loss functions and get the total loss.
    classification_loss = slim.losses.softmax_cross_entropy(scene_predictions, scene_labels)
    sum_of_squares_loss = slim.losses.sum_of_squares(depth_predictions, depth_labels)
    pose_loss = MyCustomLossFunction(pose_predictions, pose_labels)
    slim.losses.add_loss(pose_loss)  # Letting TF-Slim know about the additional loss.

    # The following two ways to compute the total loss are equivalent:
    regularization_loss = tf.add_n(slim.losses.get_regularization_losses())
    total_loss1 = classification_loss + sum_of_squares_loss + pose_loss + regularization_loss

    # get_total_loss() sums every loss registered with TF-Slim plus the
    # regularization losses. It matches total_loss1 only when the three losses
    # above are the only ones registered in the graph.
    total_loss2 = slim.losses.get_total_loss()

    # Restoring saved variables with slim

    # Create some variables (remaining arguments elided).
    v1 = slim.variable('v1', ...)
    v2 = slim.variable('nested/v2', ...)
    ...

    # Get list of variables to restore (which contains only 'v2').
    variables_to_restore = slim.get_variables_by_name("v2")

    # Create the saver which will be used to restore the variables.
    restorer = tf.train.Saver(variables_to_restore)

    with tf.Session() as sess:
        # Restore variables from disk.
        restorer.restore(sess, "/tmp/model.ckpt")
        print("Model restored.")
    
    # Adding a name prefix when restoring variables.
    # Suppose our network's variables are named like conv1/weights, while the
    # VGG checkpoint stores them as vgg16/conv1/weights. Restoring with the
    # unmodified names would fail because the checkpoint has no such entries.
    def name_in_checkpoint(var, prefix='vgg16/'):
        """Return the name under which ``var`` is stored in the checkpoint.

        Args:
            var: a variable whose ``op.name`` is its in-graph name,
                e.g. 'conv1/weights'.
            prefix: scope prefix used in the checkpoint; defaults to 'vgg16/'.

        Returns:
            ``prefix + var.op.name``, e.g. 'vgg16/conv1/weights'.
        """
        return prefix + var.op.name
    
    # Map each checkpoint-side name to the graph variable that should receive it.
    variables_to_restore = {
        name_in_checkpoint(model_var): model_var
        for model_var in slim.get_model_variables()
    }
    restorer = tf.train.Saver(variables_to_restore)

    with tf.Session() as sess:
        # Load the renamed variables from the checkpoint on disk.
        restorer.restore(sess, "/tmp/model.ckpt")

    训练模型

    在该例中,slim.learning.train 会反复执行 train_op:每一步计算损失并应用一次梯度更新。logdir 指定 checkpoint 和 event 文件的存储路径。训练的总步数可以设为任意值,这里通过 number_of_steps=1000 设为 1000 步。最后,save_summaries_secs=300 表示每 5 分钟计算一次 summaries,save_interval_secs=600 表示每 10 分钟保存一次模型的 checkpoint。

    g = tf.Graph()

    # Build everything inside g. Otherwise the ops would be created in the
    # default graph and g would never be used.
    with g.as_default():
        # Create the model and specify the losses...
        ...

        total_loss = slim.losses.get_total_loss()
        learning_rate = ...  # Choose a learning rate.
        optimizer = tf.train.GradientDescentOptimizer(learning_rate)

        # create_train_op ensures that each time we ask for the loss, the update_ops
        # are run and the gradients being computed are applied too.
        train_op = slim.learning.create_train_op(total_loss, optimizer)
        logdir = ...  # Where checkpoints are stored.

        slim.learning.train(
            train_op,
            logdir,
            number_of_steps=1000,     # Stop after 1000 gradient steps.
            save_summaries_secs=300,  # Compute summaries every 5 minutes.
            save_interval_secs=600)   # Save a checkpoint every 10 minutes.

    Fine-Tuning a Model on a different task

    假设我们有一个已经预训练好的VGG16模型,它是在拥有1000个类别的ImageNet数据集上训练的。现在我们想把它应用到只有20个类别的Pascal VOC数据集上。为此,我们可以用预训练模型中除最后几个全连接层之外的所有层的权重来初始化新模型:

    # Load the Pascal VOC data
    image, label = MyPascalVocDataLoader(...)
    images, labels = tf.train.batch([image, label], batch_size=32)

    # Create the model with a 20-way classifier (Pascal VOC has 20 classes).
    # NOTE(review): assumes vgg.vgg_16 is slim's nets/vgg.py version, which
    # returns (logits, end_points) -- confirm for your vgg module.
    predictions, _ = vgg.vgg_16(images, num_classes=20)
    train_op = slim.learning.create_train_op(...)

    # Specify where the Model, trained on ImageNet, was saved.
    model_path = '/path/to/pre_trained_on_imagenet.checkpoint'

    # Specify where the new model will live:
    log_dir = '/path/to/my_pascal_model_dir/'

    # Restore only the convolutional layers. The exclude entries are matched as
    # scope prefixes of the full variable names, so they must include the
    # network's top-level scope ('vgg_16/...'). A bare 'fc6' would match nothing.
    variables_to_restore = slim.get_variables_to_restore(
        exclude=['vgg_16/fc6', 'vgg_16/fc7', 'vgg_16/fc8'])
    init_fn = slim.assign_from_checkpoint_fn(model_path, variables_to_restore)

    # Start training.
    slim.learning.train(train_op, log_dir, init_fn=init_fn)

    evaluation loop

    import math

    import tensorflow as tf

    slim = tf.contrib.slim

    # Load the data
    images, labels = load_data(...)

    # Define the network
    predictions = MyModel(images)

    # Choose the metrics to compute. aggregate_metric_map splits each metric
    # into (value, update_op), so every entry must be a streaming metric that
    # returns such a pair.
    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
        'accuracy': slim.metrics.streaming_accuracy(predictions, labels),
        'precision': slim.metrics.streaming_precision(predictions, labels),
        'recall': slim.metrics.streaming_recall(predictions, labels),
    })

    # Create the summary ops such that they also print out to std output:
    summary_ops = []
    for metric_name, metric_value in names_to_values.items():
        op = tf.summary.scalar(metric_name, metric_value)
        op = tf.Print(op, [metric_value], metric_name)
        summary_ops.append(op)

    num_examples = 10000
    batch_size = 32
    num_batches = math.ceil(num_examples / float(batch_size))

    # Setup the global step.
    slim.get_or_create_global_step()

    checkpoint_dir = ...  # Where the trained model's checkpoints are read from.
    output_dir = ...  # Where the summaries are stored.
    eval_interval_secs = ...  # How often to run the evaluation.
    slim.evaluation.evaluation_loop(
        'local',
        checkpoint_dir,
        output_dir,
        num_evals=num_batches,
        # list(): in Python 3 .values() is a view, which session.run cannot fetch.
        eval_op=list(names_to_updates.values()),
        summary_op=tf.summary.merge(summary_ops),
        eval_interval_secs=eval_interval_secs)
  • 相关阅读:
    大数据(7)
    大数据(6)
    大数据(5)
    大数据(4)
    头发护理 -- 生发养发
    Sublime 中 SFTP插件的使用
    大数据(3)
    Apache Spark源码走读之5 -- DStream处理的容错性分析
    Apache Spark源码走读之4 -- DStream实时流数据处理
    Apache Spark源码走读之3 -- Task运行期之函数调用关系分析
  • 原文地址:https://www.cnblogs.com/elitphil/p/12028419.html
Copyright © 2020-2023  润新知