• 10.损失函数及其梯度(均方差梯度[使用线性函数y=w*x+b作为激活函数])


    1.MSE(均方差)梯度

    (1)均方差MSE

    (2)MSE求梯度

     

     【注】例如网络形式为线性感知机:ƒ(x)=w*x+b这里只是举例,具体用什么样的函数需要根据实际的网络结构。

    对w求导则是:Δƒw(w)/Δw

    对b求导则是:Δƒb(b)/Δb

    (3)均方差在pytorch中如何求梯度

    (3.1.1)torch.autograd.grad(loss,[w1,w2...........])

     

     【注】pytorch中mse_loss的自动微分:

    F.mse_loss(label,pred) pred的为线性感知机中的w*x+b,label为x。

    torch.autograd.grad(mse,para)para为线性感知机中的w和b参数。其中第一个参数必须为维度为1长度为1的tensor。

    【注】只有浮点数型数据才能计算梯度,故上图中会出现23和24行下面的错误。requires_grad_()可以对tensor类型的数据进行更新,使其可以进行梯度运算。

    (3.1.2)loss.backward()

     (3.1.3)pytorch中损失函数求梯度的两种方法总结

     [注]两种方式返回值的形式不同:

    第一种为【w1 grad,w2 grad】

    第二种为w1.grad或者w2.grad等。

    [注]可以对tensor类型的数据进行.norm查看tensor的norm,也可以对梯度信息进行.norm。

  • 相关阅读:
    Struts2框架(二)
    Struts2框架(一)
    jsp定义全局的错误处理
    BeanUtils的使用、Java中的路径问题
    IntelliJ IDEA 14.1.4(Window)快捷键
    Log4J日志组件
    注解
    反射
    泛型
    AndroidStudio开发工具快捷键(转)
  • 原文地址:https://www.cnblogs.com/jiafeng1996/p/15050284.html
Copyright © 2020-2023  润新知