• 小白学PyTorch 动态图与静态图的浅显理解


    文章来自公众号【机器学习炼丹术】,回复“炼丹”即可获得海量学习资料哦!

    本章节缕一缕PyTorch的动态图机制与Tensorflow的静态图机制(最新版的TF也支持动态图了似乎)。

    1 动态图的初步推导

    • 计算图是用来描述运算的有向无环图
    • 计算图有两个主要元素:结点(Node)和边(Edge);
    • 结点表示数据 ,如向量、矩阵、张量;
    • 边表示运算 ,如加减乘除卷积等;

    上图是用计算图表示:

    (y=(x+w)∗(w+1)y=(x+w)∗(w+1))

    其中呢,(a=x+w)(b=w+1) , (y=a∗b). (a和b是类似于中间变量的那种感觉。)

    Pytorch在计算的时候,就会把计算过程用上面那样的动态图存储起来。现在我们计算一下y关于w的梯度:

    (frac{partial y}{partial w} = frac{partial y}{partial a} frac{partial a}{partial w} + frac{partial y}{partial b} frac{partial b}{partial w})
    (=2 imes w + x + 1=5)

    (上面的计算中,w=1,x=2)

    现在我们用Pytorch的代码来实现这个过程:

    import torch
    w = torch.tensor([1.],requires_grad = True)
    x = torch.tensor([2.],requires_grad = True)
    
    a = w+x
    b = w+1
    y = a*b
    
    y.backward()
    print(w.grad)
    

    得到的结果:

    2 动态图的叶子节点

    这个图中的叶子节点,是w和x,是整个计算图的根基。之所以用叶子节点的概念,是为了减少内存,在反向传播结束之后,非叶子节点的梯度会被释放掉 , 我们依然用上面的例子解释:

    import torch
    w = torch.tensor([1.],requires_grad = True)
    x = torch.tensor([2.],requires_grad = True)
    
    a = w+x
    b = w+1
    y = a*b
    
    y.backward()
    print(w.is_leaf,x.is_leaf,a.is_leaf,b.is_leaf,y.is_leaf)
    print(w.grad,x.grad,a.grad,b.grad,y.grad)
    

    运行结果是:

    可以看到只有x和w是叶子节点,然后反向传播计算完梯度后(.backward()之后),只有叶子节点的梯度保存下来了。

    当然也可以通过.retain_grad()来保留非任意节点的梯度值。

    import torch
    w = torch.tensor([1.],requires_grad = True)
    x = torch.tensor([2.],requires_grad = True)
    
    a = w+x
    a.retain_grad()
    b = w+1
    y = a*b
    
    y.backward()
    print(w.is_leaf,x.is_leaf,a.is_leaf,b.is_leaf,y.is_leaf)
    print(w.grad,x.grad,a.grad,b.grad,y.grad)
    

    运行结果:

    3. grad_fn

    torch.tensor有一个属性grad_fn,grad_fn的作用是记录创建该张量时所用的函数,这个属性反向传播的时候会用到。例如在上面的例子中,y.grad_fn=MulBackward0,表示y是通过乘法得到的。所以求导的时候就是用乘法的求导法则。同样的,a.grad=AddBackward0表示a是通过加法得到的,使用加法的求导法则。

    import torch
    w = torch.tensor([1.],requires_grad = True)
    x = torch.tensor([2.],requires_grad = True)
    
    a = w+x
    a.retain_grad()
    b = w+1
    y = a*b
    
    y.backward()
    print(y.grad_fn)
    print(a.grad_fn)
    print(w.grad_fn)
    

    运行结果是:

    叶子节点的.grad_fn是None。

    4 静态图

    两者的区别用一句话概括就是:

    • 动态图:pytorch使用的,运算与搭建同时进行;灵活,易调节。
    • 静态图:老tensorflow使用的,先搭建图,后运算;高效,不灵活。

    静态图我们是需要先定义好运算规则流程的。比方说,我们先给出

    (a = x+w) , (b=w+1) , (y=a imes b)

    然后把上面的运算流程存储下来,然后把w=1,x=2放到上面运算框架的入口位置进行运算。而动态图是直接对着已经赋值的w和x进行运算,然后变运算变构建运算图。

    在一个课程http://cs231n.stanford.edu/slides/2018/cs231n_2018_lecture08.pdf中的第125页,有这样的一个对比例子:

    这个代码是Tensorflow的,构建运算的时候,先构建运算框架,然后再把具体的数字放入其中。整个过程类似于训练神经网络,我们要构建好模型的结构,然后再训练的时候再吧数据放到模型里面去。又类似于在旅游的时候,我们事先定要每天的行程路线,然后每天按照路线去行动。

    动态图呢,就是直接对数据进行运算,然后动态的构建出运算图。很符合我们的运算习惯。

    两者的区别在于,静态图先说明数据要怎么计算,然后再放入数据。假设要放入50组数据,运算图因为是事先构建的,所以每一次计算梯度都很快、高效;动态图的运算图是在数据计算的同时构建的,假设要放入50组数据,那么就要生成50次运算图。这样就没有那么高效。所以称为动态图

    动态图虽然没有那么高效,但是他的优点有以下:

    1. 更容易调试。
    2. 动态计算更适用于自然语言处理。(这个可能是因为自然语言处理的输入往往不定长?)
    3. 动态图更面向对象编程,我们会感觉更加自然。
    人不可傲慢。
  • 相关阅读:
    java、el表达式中保留小数的方法
    EL表达式取整数或者取固定小数位数的简单实现
    Spring框架学习之第8节
    shell脚本接收输入
    awk除去重复行
    awk过滤统计不重复的行
    Spring框架学习之第7节
    jsp中利用checkbox进行批量删除
    javaScript解决Form的嵌套
    Spring框架学习之第6节
  • 原文地址:https://www.cnblogs.com/PythonLearner/p/13548104.html
Copyright © 2020-2023  润新知