首更:
由于TensorFlow的奇怪形式,所以载入保存的是sess,把会话中当前激活的变量保存下来,所以必须保证(其他网络也要求这个)保存网络和载入网络的结构一致,且变量名称必须一致,这是caffe...好吧,caffe也没有这种python风格的设定...
废话少说,导入包:
1 import numpy as np 2 import tensorflow as tf
保存会话:
1 W = tf.Variable([[1,2,3],[4,5,6]],dtype=tf.float32) 2 b = tf.Variable([[1,2,3]],dtype=tf.float32) 3 4 init = tf.global_variables_initializer() 5 saver = tf.train.Saver() # <--------- 6 7 with tf.Session() as sess: 8 sess.run(init) 9 save_path = saver.save(sess,'./my_net/saver_net.ckpt') # <---------
载入会话:
1 W = tf.Variable(np.arange(6).reshape((2,3)),dtype=tf.float32) 2 b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32) 3 4 saver = tf.train.Saver() 5 6 with tf.Session() as sess: 7 saver.restore(sess,'./my_net/saver_net.ckpt') # <--------- 8 print('Weight: ',sess.run(W)) 9 print('biases: ',sess.run(b))
输出如下:
Weight: [[ 1. 2. 3.] [ 4. 5. 6.]] biases: [[ 1. 2. 3.]]
载入会话会加载之前保存的变量,所以不需要tf.global_variables_initializer()激活本次变量了。
再更:
引入节点名称后,只要tf变量节点的名称一致,python变量名不一致也能完美继承,也就是说tf变量节点的名称识别权限大于python变量名
详细的命名规则下节有介绍:『TensorFlow』第八弹_变量与命名空间_固有结界
保存模型:
1 W = tf.Variable([[1,2,3],[4,5,6]],dtype=tf.float32,name='W') # <------ 2 b = tf.Variable([[1,2,3]],dtype=tf.float32,name='b') # <------ 3 4 init = tf.global_variables_initializer() 5 saver = tf.train.Saver() 6 7 with tf.Session() as sess: 8 sess.run(init) 9 save_path = saver.save(sess,'./my_net/saver_net.ckpt')
W--’W‘,b--’b‘
载入模型:
1 W = tf.Variable(np.arange(6).reshape((2,3)),dtype=tf.float32') # <------ 2 b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32') # <------ 3 4 saver = tf.train.Saver() 5 6 with tf.Session() as sess: 7 saver.restore(sess,'./my_net/saver_net.ckpt') 8 print('Weight: ',sess.run(W)) 9 print('biases: ',sess.run(b))
W,b
结果报错
载入模型:
1 W = tf.Variable(np.arange(6).reshape((2,3)),dtype=tf.float32,name='W') # <------ 2 a = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32,name='b') # <------ 3 4 saver = tf.train.Saver() 5 6 with tf.Session() as sess: 7 saver.restore(sess,'./my_net/saver_net.ckpt') 8 print('Weight: ',sess.run(W)) 9 print('biases: ',sess.run(a))
W-’W‘,a--’b'
1 INFO:tensorflow:Restoring parameters from ./my_net/saver_net.ckpt 2 Weight: 3 [[ 1. 2. 3.] 4 [ 4. 5. 6.]] 5 biases: 6 [[ 1. 2. 3.]]