• pytorch requires_grad = True的意思


    计算图通常包含两种元素,一个是 tensor,另一个是 Function。张量 tensor 不必多说,但是大家可能对 Function 比较陌生。这里 Function 指的是在计算图中某个节点(node)所进行的运算,比如加减乘除卷积等等之类的,Function 内部有 forward() 和 backward() 两个方法,分别应用于正向、反向传播。

    当我们创建一个张量 (tensor) 的时候,如果没有特殊指定的话,那么这个张量是默认是不需要求导的。们在训练一个网络的时候,我们从 DataLoader 中读取出来的一个 mini-batch 的数据,这些输入默认是不需要求导的,其次,网络的输出我们没有特意指明需要求导吧,Ground Truth 我们也没有特意设置需要求导吧。这么一想,哇,那我之前的那些 loss 咋还能自动求导呢?其实原因就是上边那条规则,虽然输入的训练数据是默认不求导的,但是,我们的 model 中的所有参数,它默认是求导的,这么一来,其中只要有一个需要求导,那么输出的网络结果必定也会需要求的。来看个实例:

    input = torch.randn(8, 3, 50, 100)
    print(input.requires_grad)
    # False
    
    net = nn.Sequential(nn.Conv2d(3, 16, 3, 1),
                        nn.Conv2d(16, 32, 3, 1))
    for param in net.named_parameters():
        print(param[0], param[1].requires_grad)
    # 0.weight True
    # 0.bias True
    # 1.weight True
    # 1.bias True
    
    output = net(input)
    print(output.requires_grad)
    # True
    

    在写代码的过程中,不要把网络的输入和 Ground Truth 的 requires_grad 设置为 True。虽然这样设置不会影响反向传播,但是需要额外计算网络的输入和 Ground Truth 的导数,增大了计算量和内存占用不说,这些计算出来的导数结果也没啥用。因为我们只需要神经网络中的参数的导数,用来更新网络,其余的导数都不需要。

    原文链接:https://zhuanlan.zhihu.com/p/67184419

  • 相关阅读:
    数据库连接JOIN
    Java面试金典
    Collections.sort详解
    Java复合优先于继承
    js算术运算符与数据类型转换
    js数组类型
    js对象类型
    CSS-API(CSS编程接口),CSSOM(css对象模型)
    从零开始--单片机十字路口交通灯控制实验
    matlab用双重循环实现费诺编码
  • 原文地址:https://www.cnblogs.com/zhang12345/p/16022622.html
Copyright © 2020-2023  润新知