• 『PyTorch』第一弹_静动态图构建if逻辑对比


    对比TensorFlow和Pytorch的动静态图构建上的差异

    静态图框架设计好了不能够修改,且定义静态图时需要使用新的特殊语法,这也意味着图设定时无法使用if、while、for-loop等结构,而是需要特殊的由框架专门设计的语法,在构建图时,我们需要考虑到所有的情况(即各个if分支图结构必须全部在图中,即使不一定会在每一次运行时使用到),使得静态图异常庞大占用过多显存。

    以动态图没有这个顾虑,它兼容python的各种逻辑控制语法,最终创建的图取决于每次运行时的条件分支选择,下面我们对比一下TensorFlow和Pytorch的if条件分支构建图的实现:

    # Author : Hellcat
    # Time   : 2018/2/9
    
    def tf_graph_if():
        import numpy as np
        import tensorflow as tf
    
        x = tf.placeholder(tf.float32, shape=(3, 4))
        z = tf.placeholder(tf.float32, shape=None)
        w1 = tf.placeholder(tf.float32, shape=(4, 5))
        w2 = tf.placeholder(tf.float32, shape=(4, 5))
    
        def f1():
            return tf.matmul(x, w1)
    
        def f2():
            return tf.matmul(x, w2)
    
        y = tf.cond(tf.less(z, 0), f1, f2)
    
        with tf.Session() as sess:
            y_out = sess.run(y, feed_dict={
                x: np.random.randn(3, 4),
                z: 10,
                w1: np.random.randn(4, 5),
                w2: np.random.randn(4, 5)})
        return y_out
    
    def t_graph_if():
        import torch as t
        from torch.autograd import Variable
    
        x = Variable(t.randn(3, 4))
        w1 = Variable(t.randn(4, 5))
        w2 = Variable(t.randn(4, 5))
    
        z = 10
        if z > 0:
            y = x.mm(w1)
        else:
            y = x.mm(w2)
    
        return y
    
    
    if __name__ == "__main__":
        print(tf_graph_if())
        print(t_graph_if())
    

     计算输出如下:

    [[ 4.0871315   0.90317607 -4.65211582  0.71610922 -2.70281982]
     [ 3.67874336 -0.58160967 -3.43737102  1.9781189  -2.18779659]
     [ 2.6638422  -0.81783319 -0.30386463 -0.61386991 -3.80232286]]
    Variable containing:
    -0.2474  0.1269  0.0830  3.4642  0.2255
     0.7555 -0.8057 -2.8159  3.7416  0.6230
     0.9010 -0.9469 -2.5086 -0.8848  0.2499
    [torch.FloatTensor of size 3x5]

    个人感觉上面的对比不太完美,如果使用TensorFlow的变量来对比,上面函数应该改写如下,

    # Author : Hellcat
    # Time   : 2018/2/9
    
    def tf_graph_if():
        import tensorflow as tf
    
        x = tf.Variable(dtype=tf.float32, initial_value=tf.random_uniform(shape=[3, 4]))
        z = tf.constant(dtype=tf.float32, value=10)
        w1 = tf.Variable(dtype=tf.float32, initial_value=tf.random_uniform(shape=[4, 5]))
        w2 = tf.Variable(dtype=tf.float32, initial_value=tf.random_uniform(shape=[4, 5]))
    
        def f1():
            return tf.matmul(x, w1)
    
        def f2():
            return tf.matmul(x, w2)
    
        y = tf.cond(tf.less(z, 0), f1, f2)
    
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            y_out = sess.run(y)
        return y_out
    

     输出没什么变化,

    [[ 1.89582038  1.12734962  0.59730953  0.99833554  0.86517167]
     [ 1.2659111   0.77320379  0.63649696  0.5804953   0.82271856]
     [ 1.92151642  1.64715886  1.19869363  1.31581473  1.5636673 ]]

    可以看到,TensorFlow的if条件分支使用函数tf.cond(tf.less(z, 0), f1, f2)来实现,这和Pytorch直接使用if的逻辑很不同,而且,动态图不需要feed,直接运行便可。简单对比,可以看到Pytorch的逻辑更为简洁,让人很感兴趣。

  • 相关阅读:
    MyBatis Generator去掉生成的注解
    IDEA git修改远程仓库地址
    Spring Boot 集成druid
    解决 SpringBoot 没有主清单属性
    Intellij IDEA 安装lombok及使用详解
    SET FOREIGN_KEY_CHECKS=0;在Mysql中取消外键约束
    @SpringBootApplication
    IDEA 创建git 分支 拉取分支
    Intellij Idea 将java项目打包成jar
    Spring Could Stream 基本用法
  • 原文地址:https://www.cnblogs.com/hellcat/p/8436955.html
Copyright © 2020-2023  润新知