• 用TensorFlow搭建网络训练、验证并测试


    原文链接  https://blog.csdn.net/yutingzhaomeng/article/details/81708261

    本文总结tensorflow使用的相关方法,包括:

    0、定义网络输入

    1、如何利用tensorflow在已有网络(如resnet)基础上搭建自己的网络结构

    2、如何添加自己的网络层

    3、如何导入已有模型(如resnet)全连接层之前部分的参数

    4、定义网络损失

    5、定义优化算子以及衰减优化算子

    6、预测网络输出

    7、保存网络模型

    8、自定义生成训练batch

    9、训练网络

    10、利用tensorboard可视化训练过程

    0、定义网络输入

    # Graph inputs; values are supplied through feed_dict in the training loop.
    image_shape = [None, 224, 224, 3]  # batch dimension left open
    inputs = tf.placeholder(dtype=tf.float32, shape=image_shape, name='inputs')
    # NOTE: the tensor name 'lables' (sic) is kept unchanged because code that
    # fetches tensors by name may depend on it.
    labels = tf.placeholder(dtype=tf.int32, shape=[None], name='lables')
    # Fed True for training steps and False for validation steps.
    is_training = tf.placeholder(dtype=tf.bool, name='is_training')
        这里inputs表示输入图像,labels表示对应的类别标签,is_training用于区分dropout和batchnorm等层在训练与测试阶段的不同行为。

    1、如何利用tensorflow在已有网络(如resnet)基础上搭建自己的网络结构

    # Map config.TRAIN.net_layer to the matching slim ResNet-v1 constructor.
    resnet_builders = {
        '50': nets.resnet_v1.resnet_v1_50,
        '101': nets.resnet_v1.resnet_v1_101,
        '152': nets.resnet_v1.resnet_v1_152,
    }
    if config.TRAIN.net_layer not in resnet_builders:
        # Fail fast: otherwise `logits` would stay undefined and the script
        # would die later with an unrelated NameError.
        raise ValueError(
            'Unsupported config.TRAIN.net_layer: {!r} (expected one of {})'.format(
                config.TRAIN.net_layer, sorted(resnet_builders)))
    with slim.arg_scope(nets.resnet_v1.resnet_arg_scope()):
        # num_classes=None builds the backbone without its classification layer,
        # so `logits` holds backbone features, not class scores.
        logits, endpoints = resnet_builders[config.TRAIN.net_layer](
            inputs, num_classes=None, is_training=is_training)
        以resnet为例,num_classes设置为None时网络不包含最后的分类全连接层,此时返回的logits实际上是经过全局平均池化后的bottleneck特征(形状为[batch, 1, 1, 2048]),而不是分类得分;endpoints为各中间层输出组成的字典。

    2、如何添加自己的网络层

    # New classification head. Its variables live under the 'Logits' scope so
    # they can be excluded when restoring the pretrained backbone.
    with tf.variable_scope('Logits'):
        # Drop the two singleton spatial axes of the pooled backbone output.
        logits = tf.squeeze(logits, axis=[1, 2])
        # slim.dropout defaults to is_training=True; pass the placeholder so
        # dropout is disabled for validation/test runs (is_training=False).
        logits = slim.dropout(logits, keep_prob=0.5, is_training=is_training,
                              scope='scope')
        # Linear layer producing one unnormalized score per class.
        logits = slim.fully_connected(logits,
                                      num_outputs=config.DATASET.num_classes,
                                      activation_fn=None, scope='fc')
        这里的variable_scope('Logits')主要用来把新加的层与resnet已有参数区分开(后面恢复预训练参数时会用到)。squeeze用于把1*1*2048的特征拉成向量,随后添加dropout层和全连接层。注意dropout需要传入is_training,否则在验证/测试阶段也会随机丢弃特征。

    3、如何导入已有模型(如resnet)全连接层之前部分的参数

    # Comma-separated variable-name prefixes that must NOT be loaded from the
    # pretrained checkpoint (the newly added head).
    checkpoint_exclude_scopes = 'Logits'
    # Start from an empty list (not None) so an empty exclude string means
    # "restore everything" instead of crashing in the loop below.
    exclusions = []
    if checkpoint_exclude_scopes:
        # Ignore empty entries: '' would prefix-match every variable name.
        exclusions = [scope.strip() for scope in checkpoint_exclude_scopes.split(',')
                      if scope.strip()]
    # Keep every model variable whose name does not start with an excluded scope.
    variables_to_restore = [
        var for var in slim.get_model_variables()
        if not any(var.op.name.startswith(exclusion) for exclusion in exclusions)
    ]
    Logits scope下的变量(即新加的全连接层)不从预训练模型中恢复,其余模型参数都通过restore从checkpoint中恢复。

    4、定义网络损失

    # Sparse softmax cross-entropy on the raw (pre-softmax) logits, one value
    # per example, then averaged over the batch.
    per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    loss = tf.reduce_mean(per_example_loss)
    5、定义优化算子以及衰减优化算子

    batch = config.TRAIN.batch_size
    sample_size = len(os.listdir(config.DATASET.image_root))
    # Iteration counter incremented by minimize(); it is not a model weight,
    # so keep it out of the trainable variables.
    global_step = tf.Variable(0, trainable=False)
    # sample_size / batch iterations per epoch -> the LR decays every 4 epochs.
    learning_rate = tf.train.exponential_decay(1e-4, global_step,
                                               decay_steps=4 * sample_size / batch,
                                               decay_rate=0.98,
                                               staircase=True)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    # Constant-LR alternative (use instead of the decayed version, not in addition):
    #     train_step = tf.train.GradientDescentOptimizer(learning_rate=0.0001).minimize(loss)
    # Batch-norm moving mean/variance updates are registered in UPDATE_OPS and
    # only run if train_step depends on them; otherwise evaluation with
    # is_training=False keeps using stale statistics.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_step = optimizer.minimize(loss, global_step=global_step)
        上面两段是二选一的写法:前者使用固定学习率的优化算子,后者使用指数衰减学习率的优化算子(若两段都执行,后者的train_step会覆盖前者)。其中,batch表示每个batch的样本数,sample_size即样本数,global_step用于记录当前的iteration,sample_size / batch即每个epoch包含的iteration数目,因此decay_steps = 4 * sample_size / batch表示每4个epoch降低一次学习率。staircase=True时,learning_rate_current = learning_rate_start * decay_rate ** floor(global_step / decay_steps);staircase=False时指数部分不取整。

    6、预测网络输出

    # NOTE: the Python name `logits` is rebound to softmax probabilities here
    # (graph tensor name 'logits'); the loss must already have been built from
    # the raw logits before this point.
    probabilities = tf.nn.softmax(logits, name='logits')
    logits = probabilities
    classes = tf.argmax(logits, axis=1, name='classes')
    # argmax yields int64; cast to match the int32 label placeholder.
    predicted = tf.cast(classes, tf.int32)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(predicted, labels), tf.float32))
    7、保存网络模型

    init = tf.global_variables_initializer()
    # Loads only the variables selected in variables_to_restore (pretrained backbone).
    saver_restore = tf.train.Saver(var_list=variables_to_restore)
    # Saves every global variable present in the graph at this point.
    all_variables = tf.global_variables()
    saver = tf.train.Saver(var_list=all_variables)
    8、自定义生成训练batch

    # Batched tensors: training images/labels, then validation images/labels.
    # NOTE(review): in a runnable script this call must come after get_label,
    # get_list, process_batch and get_batch are defined (they follow below in
    # this article); as written here it would raise NameError.
    images, truths, valid_imgs, valid_trus = get_batch()
    def get_label(xml_path):
        """Read the pointer-panel reading from one annotation XML file.

        Only top-level <object> elements whose <name> text contains 'b' are
        considered. When exactly one such object exists, the text after its
        last 'b' is parsed as a float and returned as a string, e.g. 'b5'
        gives ['5.0']. Otherwise the two-element sentinel [[], []] is
        returned; callers detect it via len(result) > 1.
        """
        candidates = ET.parse(xml_path).findall('object')
        panel_names = [node.find('name').text for node in candidates
                       if 'b' in node.find('name').text]
        # Only one panel per image is supported (no center-based selection yet).
        if len(panel_names) != 1:
            return [[], []]
        reading = panel_names[0].split('b')[-1]
        return [str(float(reading))]

    def get_list(shuffle_seed=0):
        """Collect labelled image paths and split them into train/validation lists.

        For every JPEG in config.DATASET.image_root, the annotation
        '<stem>.xml' is read from config.DATASET.label_root via get_label.
        Images whose annotation does not yield exactly one panel are skipped.
        Labels present in config.DATASET.range_dict are mapped through it; any
        other label gets the extra index len(range_dict).
        # NOTE(review): presumably num_classes == len(range_dict) + 1 — confirm.

        Args:
            shuffle_seed: seed for a reproducible shuffle of the file list
                before splitting, so the validation set is a random sample
                rather than an arbitrary directory-listing prefix. Pass None
                to keep sorted filename order.

        Returns:
            (train_list, train_label, valid_list, valid_label). The first
            int(N * config.DATASET.valid_ratio) samples form the validation set.
        """
        import random

        range_dict = config.DATASET.range_dict
        # os.listdir order is arbitrary; sort first so the split is reproducible.
        file_names = sorted(os.listdir(config.DATASET.image_root))
        if shuffle_seed is not None:
            random.Random(shuffle_seed).shuffle(file_names)

        image_list = []
        label_list = []
        for file_name in file_names:
            stem, ext = os.path.splitext(file_name)
            # Only JPEGs are usable downstream (process_batch uses decode_jpeg);
            # skipping other files also avoids parsing a nonexistent XML path.
            if ext.lower() not in ('.jpg', '.jpeg'):
                continue
            image_label = get_label(os.path.join(config.DATASET.label_root, stem + '.xml'))
            # get_label returns the two-element sentinel [[], []] when the
            # annotation does not contain exactly one panel.
            if len(image_label) > 1:
                continue
            label_key = image_label[0]
            if label_key in range_dict:
                label_list.append(range_dict[label_key])
            else:
                label_list.append(len(range_dict))
            image_list.append(os.path.join(config.DATASET.image_root, file_name))

        valid_num = int(len(image_list) * config.DATASET.valid_ratio)
        return (image_list[valid_num:], label_list[valid_num:],
                image_list[:valid_num], label_list[:valid_num])

    def process_batch(input_quene):
        """Decode, crop/pad and standardize one (path, label) pair, then batch it.

        Args:
            input_quene: the [image_path, label] pair produced by
                tf.train.slice_input_producer (parameter name kept for
                compatibility with existing callers).

        Returns:
            image_batch: float32 tensor [batch_size, height, width, 3].
            label_batch: label tensor reshaped to [batch_size].
        """
        image_path, label = input_quene[0], input_quene[1]
        image = tf.image.decode_jpeg(tf.read_file(image_path), channels=3)
        # The signature is (image, target_height, target_width): height first.
        # Passing width first would transpose non-square target sizes.
        image = tf.image.resize_image_with_crop_or_pad(image,
                                                       config.DATASET.height,
                                                       config.DATASET.width)
        image = tf.image.per_image_standardization(image)

        image_batch, label_batch = tf.train.batch([image, label],
                                                  batch_size=config.TRAIN.batch_size,
                                                  capacity=config.TRAIN.capacity,
                                                  num_threads=config.TRAIN.num_threads)
        label_batch = tf.reshape(label_batch, [config.TRAIN.batch_size])
        image_batch = tf.cast(image_batch, tf.float32)

        return image_batch, label_batch


    def get_batch():
        """Build the training and validation input pipelines.

        Returns:
            (train_images, train_labels, valid_images, valid_labels) as batched
            tensors produced by process_batch.
        """
        train_paths, train_targets, valid_paths, valid_targets = get_list()

        def _pipeline(paths, targets):
            # slice_input_producer emits one (path, label) pair per dequeue.
            return process_batch(tf.train.slice_input_producer([paths, targets]))

        train_image_batch, train_label_batch = _pipeline(train_paths, train_targets)
        valid_image_batch, valid_label_batch = _pipeline(valid_paths, valid_targets)
        return train_image_batch, train_label_batch, valid_image_batch, valid_label_batch
    9、训练网络

    # NOTE(review): tfConfig (presumably a tf.ConfigProto) is defined elsewhere — not shown here.
    with tf.Session(config=tfConfig) as sess:

        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        train_writer = None
        valid_writer = None
        # Everything after the queue runners start runs under try/finally so the
        # reader threads are always stopped and joined, and the event files are
        # flushed, even if restore or a training step raises.
        try:
            # =============================Import Pretrained Parameter=========================== #
            # saver_restore covers only variables_to_restore (the pretrained part).
            saver_restore.restore(sess, config.TRAIN.model_path)

            # ================================TensorBoard Related================================ #
            tf.summary.image('inputs', inputs)
            tf.summary.scalar('loss', loss)
            tf.summary.scalar('accuracy', accuracy)
            tf.summary.scalar('learning rate', learning_rate)
            merged_summary_op = tf.summary.merge_all()
            train_log_dir = os.path.join(config.TRAIN.log_path, 'train')
            valid_log_dir = os.path.join(config.TRAIN.log_path, 'valid')
            # Start every run with empty log directories.
            for log_dir in (train_log_dir, valid_log_dir):
                if os.path.exists(log_dir):
                    shutil.rmtree(log_dir)
            train_writer = tf.summary.FileWriter(train_log_dir, sess.graph)
            valid_writer = tf.summary.FileWriter(valid_log_dir)

            for i in range(config.TRAIN.num_iterations):
                if coord.should_stop():
                    break
                # Pull one training and one validation batch out of the input
                # queues, then feed them through the placeholders.
                images_, truths_ = sess.run([images, truths])
                valid_imgs_, valid_trus_ = sess.run([valid_imgs, valid_trus])

                summary_str, _, loss_, acc_ = sess.run(
                    [merged_summary_op, train_step, loss, accuracy],
                    feed_dict={inputs: images_, labels: truths_, is_training: True})
                # Validation pass: no train_step, is_training=False.
                valid_str, vloss, vacc = sess.run(
                    [merged_summary_op, loss, accuracy],
                    feed_dict={inputs: valid_imgs_, labels: valid_trus_, is_training: False})

                print('Step: {}, Loss: {:.4f}, Accuracy: {:.4f}, Valid Loss: {:.4f}, Valid Accuracy: {:.4f}'.format(i+1, loss_, acc_, vloss, vacc))

                # if (i+1) % 1000 == 0:
                #     saver.save(sess, config.TRAIN.save_path)
                #     print('save mode to {}'.format(config.TRAIN.save_path))

                train_writer.add_summary(summary_str, i)
                valid_writer.add_summary(valid_str, i)
        finally:
            # close() flushes pending events; FileWriter otherwise flushes only periodically.
            if train_writer is not None:
                train_writer.close()
            if valid_writer is not None:
                valid_writer.close()
            coord.request_stop()
            coord.join(threads)
    10、利用tensorboard可视化训练过程

    # Must run inside the `with tf.Session(...) as sess:` block (uses sess.graph).
    tf.summary.image('inputs', inputs)
    scalar_summaries = (('loss', loss),
                        ('accuracy', accuracy),
                        ('learning rate', learning_rate))
    for summary_tag, summary_tensor in scalar_summaries:
        tf.summary.scalar(summary_tag, summary_tensor)
    merged_summary_op = tf.summary.merge_all()

    train_log_dir = os.path.join(config.TRAIN.log_path, 'train')
    valid_log_dir = os.path.join(config.TRAIN.log_path, 'valid')
    # Wipe logs from previous runs so TensorBoard shows only this run.
    for stale_dir in (train_log_dir, valid_log_dir):
        if os.path.exists(stale_dir):
            shutil.rmtree(stale_dir)
    train_writer = tf.summary.FileWriter(train_log_dir, sess.graph)
    valid_writer = tf.summary.FileWriter(valid_log_dir)

  • 相关阅读:
    PHP简单模拟登录功能实例分享
    一个form表单,多个提交按钮
    jquery validation验证身份证号、护照、电话号码、email
    MockMvc和Mockito之酷炫使用
    Java8 Stream API
    第一章 Lambda表达式
    Java中线程顺序执行
    单元测试之获取Spring下所有Bean
    iBatis之type
    json解析之jackson ObjectMapper
  • 原文地址:https://www.cnblogs.com/happytaiyang/p/11618659.html
Copyright © 2020-2023  润新知