• 10.损失函数及其梯度(均方差梯度[使用线性函数y=w*x+b作为激活函数])


    1.MSE(均方差)梯度

    (1)均方差MSE

    (2)MSE求梯度

     

     【注】例如网络形式为线性感知机:ƒ(x)=w*x+b这里只是举例,具体用什么样的函数需要根据实际的网络结构。

    对w求导则是:Δƒw(w)/Δw

    对b求导则是:Δƒb(b)/Δb

    (3)均方差在pytorch中如何求梯度

    (3.1.1)torch.autograd.grad(loss,[w1,w2...........])

     

     【注】pytorch中mse_loss的自动微分:

    F.mse_loss(label,pred) pred的为线性感知机中的w*x+b,label为x。

    torch.autograd.grad(mse,para)para为线性感知机中的w和b参数。其中第一个参数必须为维度为1长度为1的tensor。

    【注】只有浮点数型数据才能计算梯度,故上图中会出现23和24行下面的错误。requires_grad_()可以对tensor类型的数据进行更新,使其可以进行梯度运算。

    (3.1.2)loss.backward()

     (3.1.3)pytorch中损失函数求梯度的两种方法总结

     [注]两种方式返回值的形式不同:

    第一种为【w1 grad,w2 grad】

    第二种为w1.grad或者w2.grad等。

    [注]可以对tensor类型的数据进行.norm查看tensor的norm,也可以对梯度信息进行.norm。

  • 相关阅读:
    C++ 单例模式
    单链表快速排序
    美团后台面经
    排序算法及其优化总结
    (转)再谈互斥量与环境变量
    互斥锁和自旋锁
    算法题总结----数组(二分查找)
    Linux里的2>&1的理解
    Ubuntu下开启mysql远程访问
    说说eclipse调优,缩短启动时间
  • 原文地址:https://www.cnblogs.com/jiafeng1996/p/15050284.html
Copyright © 2020-2023  润新知