在 Tensorflow 当中有两种途径生成变量 variable, 一种是 tf.get_variable()
, 另一种是 tf.Variable()
.
使用tf.get_variable()
定义的变量不会被tf.name_scope()
当中的名字所影响
1 import tensorflow as tf 2 3 with tf.name_scope("a_name_scope"): 4 initializer = tf.constant_initializer(value=1) 5 var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32, initializer=initializer) 6 var2 = tf.Variable(name='var2', initial_value=[2], dtype=tf.float32) 7 var21 = tf.Variable(name='var2', initial_value=[2.1], dtype=tf.float32) 8 var22 = tf.Variable(name='var2', initial_value=[2.2], dtype=tf.float32) 9 10 11 with tf.Session() as sess: 12 sess.run(tf.initialize_all_variables()) 13 print(var1.name) # var1:0 14 print(sess.run(var1)) # [ 1.] 15 print(var2.name) # a_name_scope/var2:0 16 print(sess.run(var2)) # [ 2.] 17 print(var21.name) # a_name_scope/var2_1:0 18 print(sess.run(var21)) # [ 2.0999999] 19 print(var22.name) # a_name_scope/var2_2:0 20 print(sess.run(var22)) # [ 2.20000005]
想要达到重复利用变量的效果, 我们就要使用 tf.variable_scope()
, 并搭配 tf.get_variable()
这种方式产生和提取变量. 不像 tf.Variable()
每次都会产生新的变量, tf.get_variable()
如果遇到了同样名字的变量时, 它会单纯的提取这个同样名字的变量(避免产生新变量). 而在重复使用的时候, 一定要在代码中强调 scope.reuse_variables()
, 否则系统将会报错, 以为你只是单纯的不小心重复使用到了一个变量.
1 with tf.variable_scope("a_variable_scope") as scope: 2 initializer = tf.constant_initializer(value=3) 3 var3 = tf.get_variable(name='var3', shape=[1], dtype=tf.float32, initializer=initializer) 4 scope.reuse_variables() 5 var3_reuse = tf.get_variable(name='var3',) 6 var4 = tf.Variable(name='var4', initial_value=[4], dtype=tf.float32) 7 var4_reuse = tf.Variable(name='var4', initial_value=[4], dtype=tf.float32) 8 9 with tf.Session() as sess: 10 sess.run(tf.global_variables_initializer()) 11 print(var3.name) # a_variable_scope/var3:0 12 print(sess.run(var3)) # [ 3.] 13 print(var3_reuse.name) # a_variable_scope/var3:0 14 print(sess.run(var3_reuse)) # [ 3.] 15 print(var4.name) # a_variable_scope/var4:0 16 print(sess.run(var4)) # [ 4.] 17 print(var4_reuse.name) # a_variable_scope/var4_1:0 18 print(sess.run(var4_reuse)) # [ 4.]
或
1 with tf.variable_scope('foo') as foo_scope: 2 v = tf.get_variable('v', [1]) 3 with tf.variable_scope('foo', reuse=True): 4 v1 = tf.get_variable('v') 5 assert v1 == v
1. 使用tf.Variable()
的时候,tf.name_scope()
和tf.variable_scope()
都会给 Variable
和 op
的 name
属性加上前缀。
2. 使用tf.get_variable()
的时候,tf.name_scope()
就不会给 tf.get_variable()
创建出来的Variable
加前缀。