我们通过tf.train.Saver()来保存和重载变量
实现是保存
# Create some variables. v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer) v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer) inc_v1 = v1.assign(v1+1) dec_v2 = v2.assign(v2-1) # Add an op to initialize the variables. init_op = tf.global_variables_initializer() # Add ops to save and restore all the variables. saver = tf.train.Saver() # Later, launch the model, initialize the variables, do some work, and save the # variables to disk. with tf.Session() as sess: sess.run(init_op) # Do some work with the model. inc_v1.op.run() dec_v2.op.run() # Save the variables to disk. save_path = saver.save(sess, "/tmp/model.ckpt") print("Model saved in path: %s" % save_path)
通过调用saver的save方法来保存,返回一个str,代表了路径。
然后展示的是我们保存部分变量和重载:
tf.reset_default_graph() # Create some variables. v1 = tf.get_variable("v1", [3], initializer = tf.zeros_initializer) v2 = tf.get_variable("v2", [5], initializer = tf.zeros_initializer) # Add ops to save and restore only `v2` using the name "v2" saver = tf.train.Saver({"v2": v2}) # Use the saver object normally after that. with tf.Session() as sess: # Initialize v1 since the saver will not. v1.initializer.run() saver.restore(sess, "/tmp/model.ckpt") print("v1 : %s" % v1.eval()) print("v2 : %s" % v2.eval())
如果Saver中不传入参数,则会将所有的变量都保存。传入字典,则会按照字典中的key-value对变量进行保存。
对于不需要feed数据就可以获取的值,比如Variable。我们可以直接使用variable.eval()将变量的值打印出来。
它相当于:
tf.get_default_session().run(t)