• Tensorflow系列——Saver的用法


    摘抄自:https://blog.csdn.net/u011500062/article/details/51728830/

    1、实例

     1 import tensorflow as tf
     2 import numpy as np
     3 
     4 x = tf.placeholder(tf.float32, shape=[None, 1])
     5 y = 4 * x + 4
     6 
     7 w = tf.Variable(tf.random_normal([1], -1, 1))
     8 b = tf.Variable(tf.zeros([1]))
     9 y_predict = w * x + b
    10 
    11 loss = tf.reduce_mean(tf.square(y - y_predict))
    12 optimizer = tf.train.GradientDescentOptimizer(0.5)
    13 train = optimizer.minimize(loss)
    14 
    15 isTrain = False
    16 train_steps = 100
    17 checkpoint_steps = 50
    18 checkpoint_dir = './checkpoint_dir/'
    19 
    20 saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b
    21 x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))
    22 
    23 with tf.Session() as sess:
    24     sess.run(tf.initialize_all_variables())
    25     if isTrain:
    26         for i in range(train_steps):
    27             sess.run(train, feed_dict={x: x_data})
    28             if (i + 1) % checkpoint_steps == 0:
    29                 saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i + 1)
    30     else:
    31         ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    33         if ckpt and ckpt.model_checkpoint_path:
    34             saver.restore(sess, ckpt.model_checkpoint_path)
    35             print("Restore Sucessfully")
    36         else:
    37             pass
    38         print(sess.run(w))
    39         print(sess.run(b))

    2、运行结果

     

     

    3、解释

    训练阶段,每经过checkpoint_steps 步保存一次变量,保存的文件夹为checkpoint_dir

    测试阶段,ckpt.model_checkpoint_path:表示模型存储的位置,不需要提供模型的名字,它会去查看checkpoint文件,看看最新的是谁,叫做什么,然后载入变量

  • 相关阅读:
    爬取药智网中的方剂信息
    日报3.13
    数据库添加出错
    Bencode
    一些安全网络协议
    代码质量不重要
    Jordan Peterson
    随身记录的缺点
    Why is Go PANICking?
    go问
  • 原文地址:https://www.cnblogs.com/wt-seu/p/10381945.html
Copyright © 2020-2023  润新知