• tf.cond()函数解析(最清晰的解释)


    欢迎关注WX公众号:【程序员管小亮】

    近期看batch_normalization的代码时碰到一个tf.cond()函数,特此记录。

    tf.cond()函数用于控制数据流向。

    通过网上查了一些文章之后,才发现使用tf.cond() 函数是控制数据流向。也就是说在TensorFlow中,tf.cond()类似于c语言中的if…else…,但是也仅仅只是类似而已。

    首先看一下官方文档:

    # 用于有条件的执行函数,当pred为True时,执行true_fn函数,否则执行false_fn函数
    tf.cond(
        pred,
        true_fn=None,
        false_fn=None,
        strict=False,
        name=None,
        fn1=None,
        fn2=None
    )
    

    参数:

    • pred:标量决定是否返回 true_fn 或 false_fn 结果。
    • true_fn:如果 pred 为 true,则被调用。
    • false_fn:如果 pred 为 false,则被调用。
    • strict:启用/禁用 “严格”模式的布尔值。
    • name:返回的张量的可选名称前缀。

    Return:
    通过调用 true_fn 或 false_fn 返回的张量。如果 callables 返回单一实例列表, 则从列表中提取元素。

    需要注意的是,pred参数是tf.bool型变量,直接写“True”或者“False”是python型bool,会报错的。因此可以是很使用tf.equal(is_training,True)的操作。

    再看一下官方例子:

    z = tf.multiply(a, b)
        result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
    

    If x < y, the tf.add operation will be executed and tf.square operation will not be executed. Since z is needed for at least one branch of the cond, the tf.multiply operation is always executed, unconditionally.
    如果x < y,将会执行tf.add操作,不会执行tf.square操作。因为cond中至少有一个分支需要z,而tf.multiply操作总是被无条件地执行。

    但是我们来看这个操作,其实是反直觉的,因为按照一般的逻辑来说,应该是用不到就不执行了,通过查询官方文档,我们看到了这么一番话:

    Although this behavior is consistent with the dataflow model of TensorFlow,it has occasionally surprised some users who expected a lazier semantics。
    虽然这种行为与 TensorFlow 的数据流模型是一致的,但有时候,还是会让有些期望慵懒的用户惊讶。

    真是善解人意、、、那么我们来看看例子理解一下:

    import tensorflow as tf
    
    a=tf.constant(2)    
    b=tf.constant(3)    
    x=tf.constant(4)    
    y=tf.constant(5)    
    z = tf.multiply(a, b)    
    result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))    
    with tf.Session() as session:    
        print(result.eval())
        print(z.eval())
        print(y.eval())
    
    > 10
    > 6
    > 5
    

    首先z = a * b = 2 * 3 = 6,然后在tf.cond()函数中,因为x<y(4<5)成立,所以执行lambda: tf.add(x, z),也就是result = x + z = 10,而不执行lambda: tf.square(y),但是执行了z = tf.multiply(a, b)。

    python课程推荐。
    在这里插入图片描述

  • 相关阅读:
    使用 Eclipse 调试 Java 程序的 10 个技巧
    oracle9i,10g再谈优化模式参数问题.
    oracle 索引
    解决IE不能在新窗口中向父窗口的下拉框添加项的问题
    获取文档的尺寸:利用Math.max的另一种方式
    揭开constructor属性的神秘面纱
    测试杂感:Windows8也许需要Account Hub
    探索式测试:探索是为了学习
    一次有教益的程序崩溃调试 (下)
    软件测试读书列表 (2013.8)
  • 原文地址:https://www.cnblogs.com/hzcya1995/p/13302852.html
Copyright © 2020-2023  润新知