• 『PyTorch』第三弹重置_Variable对象


    『PyTorch』第三弹_自动求导

    torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现

    Varibale包含三个属性:

    • data:存储了Tensor,是本体的数据
    • grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致
    • grad_fn:指向Function对象,用于反向传播的梯度计算之用

    data

    import torch as t
    from torch.autograd import Variable
    
    x = Variable(t.ones(2, 2), requires_grad = True)
    x  # 实际查询的是.data,是个Tensor
    

    实际上查询x和查询x.data返回结果一致,

    Variable containing:

     1 1

     1 1

     [torch.FloatTensor of size 2x2]

    梯度求解

    构建一个简单的方程:y = x[0,0] + x[0,1] + x[1,0] + x[1,1],Variable的运算结果也是Variable,但是,中间结果反向传播中不会被求导()

    这和TensorFlow不太一致,TensorFlow中中间运算果数据结构均是Tensor,

    y = x.sum()
    
    y
    """
      Variable containing:
       4
      [torch.FloatTensor of size 1]
    """
    

    可以查看目标函数的.grad_fn方法,它用来求梯度,

    y.grad_fn
    """
        <SumBackward0 at 0x18bcbfcdd30>
    """
    
    y.backward()  # 反向传播
    x.grad  # Variable的梯度保存在Variable.grad中
    """
      Variable containing:
       1  1
       1  1
      [torch.FloatTensor of size 2x2]
    """
    

    grad属性保存在Variable中,新的梯度下来会进行累加,可以看到再次求导后结果变成了2,

    y.backward()
    x.grad  # 可以看到变量梯度是累加的
    """
        Variable containing:
         2  2
         2  2
        [torch.FloatTensor of size 2x2]
    """
    

    所以要归零,

    x.grad.data.zero_()  # 归零梯度,注意,在torch中所有的inplace操作都是要带下划线的,虽然就没有.data.zero()方法
    
    """
     0  0
     0  0
    [torch.FloatTensor of size 2x2]
    """
    

    对比Variable和Tensor的接口,相差无两,

    Variable和Tensor的接口近乎一致,可以无缝切换
    
    x = Variable(t.ones(4, 5))
    
    y = t.cos(x)                         # 传入Variable
    x_tensor_cos = t.cos(x.data)  # 传入Tensor
    
    print(y)
    print(x_tensor_cos)
    
    """
    Variable containing:
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
    [torch.FloatTensor of size 4x5]
    
    
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
    [torch.FloatTensor of size 4x5]
    """
    
  • 相关阅读:
    (4.38)sql server中的事务控制及try cache错误处理
    (4.37)sql server中的系统函数
    【3.5】mysql常用开发规则30条
    Linux学习笔记(17)Linux防火墙配置详解
    (5.16)mysql高可用系列——keepalived+mysql双主ha
    mysql online DDL
    (5.3.8)sql server升级打补丁
    python request
    python 模块被引用多次但是里面的全局表达式总共只会执行一次
    Protocol Buffer Basics: Python
  • 原文地址:https://www.cnblogs.com/hellcat/p/8439055.html
Copyright © 2020-2023  润新知