• 深度学习 Deep Learning UFLDL 最新Tutorial 学习笔记 4:Debugging: Gradient Checking


    1 Gradient Checking 说明

    前面我们已经实现了Linear Regression和Logistic Regression。关键在于代价函数Cost Function和其梯度Gradient的计算。

    在Gradient的计算中,我们一般採用推导出来的计算公式来进行计算。

    可是我们看到,推导出来的公式是复杂的。特别到后面的神经网络,更加复杂。这就产生了一个问题,我们怎样推断我们编写的程序就是计算出正确的Gradient呢?

    解决的方法就是通过数值计算的方法来估算Gradient然后与用公式计算出来的数据做对照,假设差距非常小,那么就说明我们的计算是对的。

    那么採用什么数值计算方法呢?
    事实上就是基于最主要的求导公式:

    ddθJ(θ)=limϵ0J(θ+ϵ)J(θϵ)2ϵ.

    我们取epsilon一个非常小的值,那么得到的数据就是导数的近似。


    因此
    g(θ)J(θ+EPSILON)J(θEPSILON)2×EPSILON.

    2 代码实现

    这里我们不须要自己Code,官方已经给出了代码。我们仅仅须要分析一下:
    这个代码用来计算gradient平均误差
    % 说明:grad_check 參数
    % fun为函数
    % num_checks 检查次数
    % varagin为參数列 var1,var2,var3...这个varagin必须放在function最后一个项
    function average_error = grad_check(fun, theta0, num_checks, varargin)
    
      delta=1e-3; 
      sum_error=0;
    
      fprintf(' Iter       i             err');
      fprintf('           g_est               g               f
    ')
    
      for i=1:num_checks
        T = theta0;
        j = randsample(numel(T),1);
        T0=T; T0(j) = T0(j)-delta;
        T1=T; T1(j) = T1(j)+delta;
    
        [f,g] = fun(T, varargin{:}); %因为fun是linear_regression或logistic_regression
        f0 = fun(T0, varargin{:});   %所以这里的varagin{:}參数为train.X,train.y 
        f1 = fun(T1, varargin{:});
    
        g_est = (f1-f0) / (2*delta);
        error = abs(g(j) - g_est);
    
        fprintf('% 5d  % 6d % 15g % 15f % 15f % 15f
    ', ...
                i,j,error,g(j),g_est,f);
    
        sum_error = sum_error + error;
      end
    
      average_error =sum_error/num_checks;
    

    那么在使用中。比方在ex1a_linreg.m中,能够这样使用:
    % Gradient Check
    average_error = grad_check(@linear_regression_vec,theta,50,train.X,train.y);
    fprintf('Average error :%f
    ',average_error);

    【本文为原创文章。转载请注明出处:blog.csdn.net/songrotek  欢迎交流哦QQ:363523441】

  • 相关阅读:
    微信开发者工具http申请图片变成https
    vue 中v-for img src 路径加载问题
    nodejs内置模块querystring中parse使用问题
    用git上传项目到github遇到的问题和解决方法
    页面刷新——微信小程序生命周期探索
    小程序项目复盘(三) 用全局变量传参的问题
    小程序项目复盘(二) wx.request异步请求处理
    小程序项目复盘(一)字符串处理问题
    微信小程序中我常用到的CSS3弹性盒子布局(flex)总结
    wx.request中POST方法传参问题,用到JSON.stringify()
  • 原文地址:https://www.cnblogs.com/llguanli/p/7158577.html
Copyright © 2020-2023  润新知