定义计算图并计算,保存其中的变量 。保存.ipynb
import tensorflow as tf tf.reset_default_graph() # Create some variables. v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer) v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer) inc_v1 = v1.assign(v1+1) dec_v2 = v2.assign(v2-1) # Add an op to initialize the variables. init_op = tf.global_variables_initializer() # Add ops to save and restore all the variables. saver = tf.train.Saver() # Later, launch the model, initialize the variables, do some work, and save the # variables to disk. with tf.Session() as sess: sess.run(init_op) # Do some work with the model. inc_v1.op.run() dec_v2.op.run() # Save the variables to disk. save_path = saver.save(sess, "./ckpt_test/model.ckpt") print("Model saved in path: %s" % save_path)
创建相同的图结构,图中的节点变量可以由已经保存的模型文件中的内容恢复处理,注意 首先要图进行清空(感觉tf公用了变量空间,所以如果没有清空会导致变量内容名称不一致)恢复.ipynb
import tensorflow as tf tf.reset_default_graph() # Create some variables. v1 = tf.get_variable("v1", shape=[3]) v2 = tf.get_variable("v2", shape=[5]) # Add ops to save and restore all the variables. saver = tf.train.Saver() # Later, launch the model, use the saver to restore variables from disk, and # do some work with the model. with tf.Session() as sess: # Restore variables from disk. saver.restore(sess, "./ckpt_test/model.ckpt") print("Model restored.") # Check the values of the variables print("v1 : %s" % v1.eval()) print("v2 : %s" % v2.eval())
所以最好在保存和恢复的文件中都先对图清空。