• TensorFlow from the Ground Up (4): Saver & restore


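    import tensorflow as tf  # TensorFlow 1.x API

    # x is fed as a 2-D batch of shape [N, 1] so that tf.matmul(x, w) is valid
    x = tf.placeholder(tf.float32)
    y = tf.placeholder(tf.float32)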
    
    w = tf.Variable(tf.zeros([1, 1], dtype=tf.float32))
    b = tf.Variable(tf.ones([1, 1], dtype=tf.float32))
    y_hat = tf.add(b, tf.matmul(x, w))
    
    # ... more setup for optimization and what not ...
    
    saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        if FLAGS.train:
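            for i in range(FLAGS.training_steps):
                # ... training loop ...
                if (i + 1) % FLAGS.checkpoint_steps == 0:
                    # Writes files prefixed model.ckpt-<step>. checkpoint_dir must
                    # end with a path separator for this concatenation to work.
                    saver.save(sess, FLAGS.checkpoint_dir + 'model.ckpt',
                               global_step=i+1)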
        else:
            # Here's where you're restoring the variables w and b.
            # Note that the graph is exactly as it was when the variables were
            # saved in a prior training run.
            ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                pass  # ... no checkpoint found ...
    
            # Now you can run the model to get predictions
            batch_x = ...  # ... load some data ...
            predictions = sess.run(y_hat, feed_dict={x: batch_x})
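
The snippet above relies on a FLAGS object it never defines, and on how saver.save names its files when global_step is given. The sketch below is one way to make both concrete without TensorFlow. It is only an illustration under stated assumptions: parse_flags and checkpoint_prefixes are hypothetical helpers, argparse stands in for whatever flag library the original code used, and the default values are made up. The only behavior taken from TensorFlow is that saver.save appends "-<global_step>" to the path prefix it is given.

```python
import argparse
import os


def parse_flags(argv=None):
    """Hypothetical stand-in for the FLAGS object used in the snippet above."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--train", action="store_true",
                        help="train and save checkpoints; otherwise restore")
    # The defaults below are illustrative assumptions, not values from the post.
    parser.add_argument("--training_steps", type=int, default=1000)
    parser.add_argument("--checkpoint_steps", type=int, default=100)
    parser.add_argument("--checkpoint_dir", default="checkpoints")
    return parser.parse_args(argv)


def checkpoint_prefixes(flags):
    """List the path prefixes the training loop would pass through saver.save.

    Mirrors the loop condition (i + 1) % checkpoint_steps == 0 and the
    "-<global_step>" suffix added when global_step is set.
    """
    # os.path.join avoids depending on a trailing separator in checkpoint_dir.
    base = os.path.join(flags.checkpoint_dir, "model.ckpt")
    return ["{}-{}".format(base, i + 1)
            for i in range(flags.training_steps)
            if (i + 1) % flags.checkpoint_steps == 0]


FLAGS = parse_flags(["--train", "--training_steps", "250",
                     "--checkpoint_steps", "100",
                     "--checkpoint_dir", "ckpts"])
prefixes = checkpoint_prefixes(FLAGS)
print(prefixes)
```

With 250 training steps and a checkpoint every 100 steps, only steps 100 and 200 are saved; the final 50 steps are never checkpointed unless the loop saves once more after it ends. Building the prefix with os.path.join also keeps the files inside checkpoint_dir, which is where tf.train.get_checkpoint_state looks for them in the restore branch.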
  • Original article: https://www.cnblogs.com/upright/p/6140428.html