• pytorch中的Variable() Learner


    函数简介

    torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现(tensor变成variable之后才能进行反向传播求梯度?用变量.backward()进行反向传播之后,var.grad中保存了var的梯度)

    x = Variable(tensor, requires_grad = True)

    Varibale包含三个属性:

    • data:存储了Tensor,是本体的数据
    • grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致
    • grad_fn:指向Function对象,用于反向传播的梯度计算之用

    用法:

    import torch
    from torch.autograd import Variable
     
    x = Variable(torch.one(2,2), requires_grad = True)
    print(x)#其实查询的是x.data,是个tensor

    举个例子求梯度:

    构建一个简单的方程:y = x[0,0] + x[0,1] + x[1,0] + x[1,1],Variable的运算结果也是Variable,但是,中间结果反向传播中不会被求导()

    这和TensorFlow不太一致,TensorFlow中中间运算果数据结构均是Tensor

    y = x.sum()
     
    y
    """
      Variable containing:
       4
      [torch.FloatTensor of size 1]
    """
     
     
    #可以查看目标函数的.grad_fn方法,它用来求梯度
    y.grad_fn
    """
        <SumBackward0 at 0x18bcbfcdd30>
    """
     
    y.backward()  # 反向传播
    x.grad  # Variable的梯度保存在Variable.grad中
    """
      Variable containing:
       1  1
       1  1
      [torch.FloatTensor of size 2x2]
    """
     
     
    #grad属性保存在Variable中,新的梯度下来会进行累加,可以看到再次求导后结果变成了2,
    y.backward()
    x.grad  # 可以看到变量梯度是累加的
    """
        Variable containing:
         2  2
         2  2
        [torch.FloatTensor of size 2x2]
    """
     
    #所以要归零
    x.grad.data.zero_()  # 归零梯度,注意,在torch中所有的inplace操作都是要带下划线的,虽然就没有.data.zero()方法
     
    """
     0  0
     0  0
    [torch.FloatTensor of size 2x2]
    """
     
     
    #对比Variable和Tensor的接口,相差无两
    x = Variable(torch.ones(4, 5))
     
    y = torch.cos(x)                         # 传入Variable
    x_tensor_cos = torch.cos(x.data)  # 传入Tensor
     
    print(y)
    print(x_tensor_cos)
     
    """
    Variable containing:
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
    [torch.FloatTensor of size 4x5]
     
     
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
    [torch.FloatTensor of size 4x5]

    参考:

    https://blog.csdn.net/u012370185/article/details/94391428

  • 相关阅读:
    技术实践 | 聊聊网易云信的信令网络库实践
    打破传统降噪技术 看网易云信在语音降噪的实践应用
    聊聊前端日志库在 SaaS 产品中的应用与设计
    WebRTC 系列之音频会话管理
    简单五步,轻松构建本土「Clubhouse」
    网易云信服务监控平台实践
    基于 Elasticsearch 的数据报表方案
    基于 WebRTC 实现自定义编码分辨率发送
    Python 设计模式—原型模式
    网络层—简单的面试问题
  • 原文地址:https://www.cnblogs.com/BlairGrowing/p/15709599.html
Copyright © 2020-2023  润新知