• 『PyTorch』第三弹重置_Variable对象


    『PyTorch』第三弹_自动求导

    torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现

    Varibale包含三个属性:

    • data:存储了Tensor,是本体的数据
    • grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致
    • grad_fn:指向Function对象,用于反向传播的梯度计算之用

    data

    import torch as t
    from torch.autograd import Variable
    
    x = Variable(t.ones(2, 2), requires_grad = True)
    x  # 实际查询的是.data,是个Tensor
    

    实际上查询x和查询x.data返回结果一致,

    Variable containing:

     1 1

     1 1

     [torch.FloatTensor of size 2x2]

    梯度求解

    构建一个简单的方程:y = x[0,0] + x[0,1] + x[1,0] + x[1,1],Variable的运算结果也是Variable,但是,中间结果反向传播中不会被求导()

    这和TensorFlow不太一致,TensorFlow中中间运算果数据结构均是Tensor,

    y = x.sum()
    
    y
    """
      Variable containing:
       4
      [torch.FloatTensor of size 1]
    """
    

    可以查看目标函数的.grad_fn方法,它用来求梯度,

    y.grad_fn
    """
        <SumBackward0 at 0x18bcbfcdd30>
    """
    
    y.backward()  # 反向传播
    x.grad  # Variable的梯度保存在Variable.grad中
    """
      Variable containing:
       1  1
       1  1
      [torch.FloatTensor of size 2x2]
    """
    

    grad属性保存在Variable中,新的梯度下来会进行累加,可以看到再次求导后结果变成了2,

    y.backward()
    x.grad  # 可以看到变量梯度是累加的
    """
        Variable containing:
         2  2
         2  2
        [torch.FloatTensor of size 2x2]
    """
    

    所以要归零,

    x.grad.data.zero_()  # 归零梯度,注意,在torch中所有的inplace操作都是要带下划线的,虽然就没有.data.zero()方法
    
    """
     0  0
     0  0
    [torch.FloatTensor of size 2x2]
    """
    

    对比Variable和Tensor的接口,相差无两,

    Variable和Tensor的接口近乎一致,可以无缝切换
    
    x = Variable(t.ones(4, 5))
    
    y = t.cos(x)                         # 传入Variable
    x_tensor_cos = t.cos(x.data)  # 传入Tensor
    
    print(y)
    print(x_tensor_cos)
    
    """
    Variable containing:
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
    [torch.FloatTensor of size 4x5]
    
    
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
    [torch.FloatTensor of size 4x5]
    """
    
  • 相关阅读:
    查询计划Hash和查询Hash
    执行计划的重用
    执行计划组件、组件、老化
    执行计划的生成
    查询反模式
    T-SQL 公用表表达式(CTE)
    SQL 操作结果集 -并集、差集、交集、结果集排序
    SQL语句
    POJ 1821 单调队列+dp
    区间gcd问题 HDU 5869 离线+树状数组
  • 原文地址:https://www.cnblogs.com/hellcat/p/8439055.html
Copyright © 2020-2023  润新知