• 『PyTorch』第五弹_深入理解autograd_中:Variable梯度探究


    查看非叶节点梯度的两种方法

    在反向传播过程中非叶子节点的导数计算完之后即被清空。若想查看这些变量的梯度,有两种方法:

    • 使用autograd.grad函数
    • 使用hook

    autograd.gradhook方法都是很强大的工具,更详细的用法参考官方api文档,这里举例说明基础的使用。推荐使用hook方法,但是在实际使用中应尽量避免修改grad的值。

    求z对y的导数

    x = V(t.ones(3))
    w = V(t.rand(3),requires_grad=True)
    y = w.mul(x)
    z = y.sum()
    
    # hook
    # hook没有返回值,参数是函数,函数的参数是梯度值
    def variable_hook(grad):
        print("hook梯度输出:
    ",grad)
    
    hook_handle = y.register_hook(variable_hook)         # 注册hook
    z.backward(retain_graph=True)                        # 内置输出上面的hook
    hook_handle.remove()                                 # 释放
    
    print("autograd.grad输出:
    ",t.autograd.grad(z,y)) # t.autograd.grad方法
    
    hook梯度输出:
     Variable containing:
     1
     1
     1
    [torch.FloatTensor of size 3]
    
    autograd.grad输出:
     (Variable containing:
     1
     1
     1
    [torch.FloatTensor of size 3]
    ,)

    多次反向传播试验

    实际就是使用retain_graph参数,

    # 构件图
    x = V(t.ones(3))
    w = V(t.rand(3),requires_grad=True)
    y = w.mul(x)
    z = y.sum()
    
    z.backward(retain_graph=True)
    print(w.grad)
    z.backward()
    print(w.grad)
    
    Variable containing:
     1
     1
     1
    [torch.FloatTensor of size 3]
    
    Variable containing:
     2
     2
     2
    [torch.FloatTensor of size 3]
    

    如果不使用retain_graph参数,

    实际上效果是一样的,AccumulateGrad object仍然会积累梯度

    # 构件图
    x = V(t.ones(3))
    w = V(t.rand(3),requires_grad=True)
    y = w.mul(x)
    z = y.sum()
    
    z.backward()
    print(w.grad)
    y = w.mul(x)  # <-----
    z = y.sum()  # <-----
    z.backward()
    print(w.grad)
    
    Variable containing:
     1
     1
     1
    [torch.FloatTensor of size 3]
    
    Variable containing:
     2
     2
     2
    [torch.FloatTensor of size 3]

    分析:

    这里的重新建立高级节点意义在这里:实际上高级节点在创建时,会缓存用于输入的低级节点的信息(值,用于梯度计算),但是这些buffer在backward之后会被清空(推测是节省内存),而这个buffer实际也体现了上面说的动态图的"动态"过程,之后的反向传播需要的数据被清空,则会报错,这样我们上面过程就分别从:保留数据不被删除&重建数据两个角度实现了多次backward过程。

    实际上第二次的z.backward()已经不是第一次的z所在的图了,体现了动态图的技术,静态图初始化之后会留在内存中等待feed数据,但是动态图不会,动态图更类似我们自己实现的机器学习框架实践,相较于静态逻辑简单一点,只是PyTorch的静态图和我们的比会在反向传播后清空存下的数据:下次要么完全重建,要么反向传播之后指定不舍弃图z.backward(retain_graph=True)。

    总之图上的节点是依赖buffer记录来完成反向传播,TensorFlow中会一直存留,PyTorch中就会backward后直接舍弃(默认时)。

  • 相关阅读:
    新购服务器流程
    nginx代理证书使用方法
    一键部署lnmp脚本
    mysql主从库配置读写分离以及备份
    Linux入门教程(更新完毕)
    Git 工作流程
    Git远程操作
    常用Git命令
    js数组去重
    Sublime Text设置快捷键让html文件在浏览器打开
  • 原文地址:https://www.cnblogs.com/hellcat/p/8449801.html
Copyright © 2020-2023  润新知