欢迎关注WX公众号:【程序员管小亮】
近期看batch_normalization的代码时碰到一个tf.cond()函数,特此记录。
tf.cond()函数用于控制数据流向。
通过网上查了一些文章之后,才发现使用tf.cond() 函数是控制数据流向。也就是说在TensorFlow中,tf.cond()类似于c语言中的if…else…,但是也仅仅只是类似而已。
首先看一下官方文档:
# 用于有条件的执行函数,当pred为True时,执行true_fn函数,否则执行false_fn函数
tf.cond(
pred,
true_fn=None,
false_fn=None,
strict=False,
name=None,
fn1=None,
fn2=None
)
参数:
- pred:标量决定是否返回 true_fn 或 false_fn 结果。
- true_fn:如果 pred 为 true,则被调用。
- false_fn:如果 pred 为 false,则被调用。
- strict:启用/禁用 “严格”模式的布尔值。
- name:返回的张量的可选名称前缀。
Return:
通过调用 true_fn 或 false_fn 返回的张量。如果 callables 返回单一实例列表, 则从列表中提取元素。
需要注意的是,pred参数是tf.bool型变量,直接写“True”或者“False”是python型bool,会报错的。因此可以是很使用tf.equal(is_training,True)的操作。
再看一下官方例子:
z = tf.multiply(a, b)
result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
If x < y, the tf.add operation will be executed and tf.square operation will not be executed. Since z is needed for at least one branch of the cond, the tf.multiply operation is always executed, unconditionally.
如果x < y,将会执行tf.add操作,不会执行tf.square操作。因为cond中至少有一个分支需要z,而tf.multiply操作总是被无条件地执行。
但是我们来看这个操作,其实是反直觉的,因为按照一般的逻辑来说,应该是用不到就不执行了,通过查询官方文档,我们看到了这么一番话:
Although this behavior is consistent with the dataflow model of TensorFlow,it has occasionally surprised some users who expected a lazier semantics。
虽然这种行为与 TensorFlow 的数据流模型是一致的,但有时候,还是会让有些期望慵懒的用户惊讶。
真是善解人意、、、那么我们来看看例子理解一下:
import tensorflow as tf
a=tf.constant(2)
b=tf.constant(3)
x=tf.constant(4)
y=tf.constant(5)
z = tf.multiply(a, b)
result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
with tf.Session() as session:
print(result.eval())
print(z.eval())
print(y.eval())
> 10
> 6
> 5
首先z = a * b = 2 * 3 = 6,然后在tf.cond()函数中,因为x<y(4<5)成立,所以执行lambda: tf.add(x, z),也就是result = x + z = 10,而不执行lambda: tf.square(y),但是执行了z = tf.multiply(a, b)。
python课程推荐。