• 【cs229-Lecture2】Linear Regression with One Variable (Week 1)(含测试数据和源码)


    从Ⅱ到Ⅳ都在讲的是线性回归,其中第Ⅱ章讲得是简单线性回归(simple linear regression, SLR)(单变量),第Ⅲ章讲的是线代基础,第Ⅳ章讲的是多元回归(大于一个自变量)。

    本文的目的主要是对Ⅱ章中出现的一些算法进行实现,适合的人群为已经看完本章节Stanford课程的学者。本人只是一名初学者,尽可能以白话的方式来说明问题。不足之处,还请指正。

    在开始讨论具体步骤之前,首先给出简要的思维路线:

    1.拥有一个点集,为了得到一条最佳拟合的直线;

    2.通过“最小二乘法”来衡量拟合程度,得到代价方程;

    3.利用“梯度下降算法”使得代价方程取得极小值点;



    首先,介绍几个概念:

    回归在数学上来说是给定一个点集,能够用一条曲线去拟合之。如果这个曲线是一条直线,那就被称为线性回归;如果曲线是一条二次曲线,就被称为二次回归,回归还有很多的变种,如locally weighted回归,logistic回归等等。

    课程中得到的h就是线性回归方程:

    image


    下面,首先来介绍一下单变量的线性回归:

    问题是这样的:给定一个点集,找出一条直线去拟合,要求拟合的效果达到最佳(最佳拟合)。

    既然是直线,我们先假设直线的方程为:image

         如图:image

        点集有了,直线方程有了,接下来,我们要做的就是计算出imageimage,使得拟合效果达到最佳(最佳拟合)。

        那么,拟合效果的评判标准是什么呢?换句话说,我们需要知道一种对拟合效果的度量。

       在这里,我们提出“最小二乘法”:(以下摘自wiki)

    最小二乘法(又称最小平方法)是一种数学优化技术。它通过最小化误差的平方和寻找数据的最佳函数匹配。

    利用最小二乘法可以简便地求得未知的数据,并使得这些求得的数据与实际数据之间误差的平方和为最小。

    对于“最小二乘法”就不再展开讨论,只要知道他是一个度量标准,我们可以用它来评判计算出的直线方程是否达到了最佳拟合就够了。

    那么,回到问题上来,在单变量的线性回归中,这个拟合效果的表达式是利用最小二乘法将未知量残差平方和最小化

    image

    结合课程,定义了一个成本函数:

    image

    其实,到这里,要是把点集的具体数值代入到成本函数中,就已经完全抽象出了一个高等数学问题(解一个二元函数的最小值问题)。

    image

    其中,a,b,c,d,e,f均为已知。

    课程中介绍了一种叫“Gradient descent”的方法——梯度下降算法

    image

    两张图说明算法的基本思想:

    imageimage

    image

    所谓梯度下降算法(一种求局部最优解的方法),举个例子就好比你现在在一座山上,你想要尽快地到达山底(极小值点),这是一个下降的过程,这里就涉及到了两个问题:1)你下山的时候,跨多大的步子(当然,肯定不是越大越好,因为有一种可能就是你一步跨地太大,正好错过了极小的位置);2)你朝哪个方向跨步(注意,这个方向是不断变化的,你每到一个新的位置,要判断一下下一步朝那个方向走才是最好的,但是有一点可以肯定的是,要想尽快到达最低点,应从最陡的地方下山)。

    那么,什么时候算是你到了一个极小点呢,显然,当你所处的位置发生的变化不断减小,直至收敛于某一位置,就说明那个位置就是一个极小值点。

    so,我们来看image的变化,则我们需要让imageimage求偏导,倒数代表变化率。也就是要朝着对陡的地方下山(因为沿着最陡显然比较快),就得到了image的变化情况:image

    image

    image

    简化之后:

    image

    步长不宜过大或过小

    image

    梯度下降法是按下面的流程进行的:(转自:http://blog.sina.com.cn/s/blog_62339a2401015jyq.html

    1)首先对θ赋值,这个值可以是随机的,也可以让θ是一个全零的向量。

    2)改变θ的值,使得J(θ)按梯度下降的方向进行减少。

            为了方便大家的理解,首先给出单变量的例子:

           eg:求image的最小值。(注:image

    image

           java代码如下:

    ·

    package OneVariable;
    
    public class OneVariable{
        public static void main(String[] args){
        double e=0.00001;//定义迭代精度
        double alpha=0.5;//定义迭代步长
        double x=0;            //初始化x
        double y0=2*x*x+3*x+1;//与初始化x对应的y值
        double y1=0;//定义变量,用于保存当前值
        while (true)
        {
            x=x-alpha*(4.0*x+3.0);
            y1=2*x*x+3*x+1;
            if (Math.abs(y1-y0)<e)//如果2次迭代的结果变化很小,结束迭代
            {
                break;
            }
            y0=y1;//更新迭代的结果
        }
        System.out.println("Min(f(x))="+y0);
        System.out.println("minx="+x);
        }
    }
    
    //输出
    Min(f(x))=1.0
    minx=-1.5

    两个变量的时候,为了更清楚,给出下面的图:

    image

    这是一个表示参数θ与误差函数J(θ)的关系图,红色的部分是表示J(θ)有着比较高的取值,我们需要的是,能够让J(θ)的值尽量的低。也就是深蓝色的部分。θ0,θ1表示θ向量的两个维度。

    在上面提到梯度下降法的第一步是给θ给一个初值,假设随机给的初值是在图上的十字点。

    然后我们将θ按照梯度下降的方向进行调整,就会使得J(θ)往更低的方向进行变化,如图所示,算法的结束将是在θ下降到无法继续下降为止。

    image

    当然,可能梯度下降的最终点并非是全局最小点,可能是一个局部最小点,可能是下面的情况:

    image

    上面这张图就是描述的一个局部最小点,这是我们重新选择了一个初始点得到的,看来我们这个算法将会在很大的程度上被初始点的选择影响而陷入局部最小点 

    一个很重要的地方值得注意的是,梯度是有方向的,对于一个向量θ,每一维分量θi都可以求出一个梯度的方向,我们就可以找到一个整体的方向,在变化的时候,我们就朝着下降最多的方向进行变化就可以达到一个最小点,不管它是局部的还是全局的。


    理论的知识就讲到这,下面,我们就用java去实现这个算法:

    梯度下降有两种:批量梯度下降和随机梯度下降。详见:http://blog.csdn.net/lilyth_lilyth/article/details/8973972

    测试数据就用课后题中的数据(ex1data1.txt),用matlab打开作图得到:

    image

    首先说明:以下源码是不正确的,具体为什么不正确我还没搞清楚!非常希望各位高手能够指正!

    测试数据及源码下载:http://pan.baidu.com/s/1mgiIVm4

    OneVariable.java
     1 package OneVariableVersion;
     2 
     3 import java.io.IOException;
     4 import java.util.List;
     5 
     6 
     7 /**
     8  * Linear Regression with One Variable
     9  * @author XBW
    10  * @date 2014年8月17日
    11  */
    12 
    13 public class OneVariable{
    14     public static final Double e=0.00001;
    15     public static List<Data> DS;
    16     public static Double step;
    17     public static Double m;
    18     
    19     /**
    20      * 计算当前参数是否符合
    21      * @param ans
    22      * @param datalist
    23      * @return
    24      */
    25     public static Ans calc(Ans ans){
    26         Double costfun;
    27         do{
    28             costfun=calcAccuracy(ans);
    29             ans=update(ans);
    30             step*=0.3;
    31         }while(Math.abs(costfun-calcAccuracy(ans))>e);
    32         ans.ifConvergence=true;
    33         return ans;
    34     }
    35     
    36     /**
    37      * 判断当前ans是否满足精度,y=t0+t1*x
    38      * @param ans
    39      * @return
    40      */
    41     public static Double calcAccuracy(Ans ans){
    42         Double cost=0.0;
    43         Double tmp;
    44         for(int i=0;i<m;i++){
    45             tmp=DS.get(i).y-(ans.theta[0]*DS.get(i).x[0]+ans.theta[1]*DS.get(i).x[1]);
    46             cost+=tmp*tmp;
    47         }
    48         cost/=(2*m);
    49         return cost;
    50     }
    51     
    52     /**
    53      * 更新ans
    54      * @param ans,学习速率为step,m为数据量
    55      * @return
    56      */
    57     public static Ans update(Ans ans){
    58         Double[] tmp=new Double[100] ;
    59         for(int i=0;i<2;i++){
    60             tmp[i]=ans.theta[i]-step*fun(ans,i);
    61         }
    62         for(int i=0;i<2;i++){
    63             ans.theta[i]=tmp[i];
    64         }
    65         return ans;
    66     }
    67     
    68     /**
    69      * 计算偏导
    70      * @return
    71      */
    72     public static Double fun(Ans ans,int xi){
    73         Double ret = 0.0;
    74         for(int i=0;i<m;i++){
    75             ret+=(ans.theta[0]*DS.get(i).x[0]+ans.theta[1]*DS.get(i).x[1]-DS.get(i).y)*DS.get(i).x[xi];
    76         }
    77         ret/=m;
    78         return ret;        
    79     }
    80     
    81     public static void main(String[] args) throws IOException{
    82         DS=new DataSet().ds;
    83         step=1.0;        
    84         m=(double)DS.size();
    85         
    86         
    87         Double[] theta={0.0,0.0};                     //初始设定theta0=0,theta1=0
    88         Ans ans=new Ans(theta,false);
    89         Ans answer;
    90         answer=calc(ans);
    91         System.out.println("theta1= "+answer.theta[0]+"      theta2="+answer.theta[1]);
    92     }
    93 }

    DataSet.java

     1 package OneVariableVersion;
     2 
     3 import java.io.BufferedReader;
     4 import java.io.File;
     5 import java.io.FileReader;
     6 import java.io.IOException;
     7 import java.util.ArrayList;
     8 import java.util.List;
     9 
    10 
    11 /**
    12  * 数据处理
    13  * @author XBW
    14  * @date 2014年8月17日
    15  */
    16 
    17 public class DataSet{
    18     String defaultpath="D:\MachineLearning\StanfordbyAndrewNg\II.LinearRegressionwithOneVariable(Week1)\homework\ex1data1.txt";
    19     
    20     List<Data> ds=new ArrayList<Data>();
    21     
    22     public DataSet() throws IOException{
    23         File dataset=new File(defaultpath);
    24         BufferedReader br = new BufferedReader(new FileReader(dataset));
    25         String tsing;
    26         while((tsing=br.readLine())!=null){
    27             String[] dlist=tsing.split(",");
    28             Data dtmp=new Data(Double.parseDouble(dlist[0]),Double.parseDouble(dlist[1]));
    29             this.ds.add(dtmp);
    30         }
    31         br.close();
    32     }
    33 }

    Ans.java

     1 package OneVariableVersion;
     2 
     3 /**
     4  * 保存结果,y=t0+t1*x
     5  * @author XBW
     6  * @date 2014年8月17日
     7  */
     8 
     9 public class Ans {
    10     Double[] theta;
    11     boolean ifConvergence;
    12     
    13     public Ans(Double[] tmp,boolean ifCon){
    14         this.theta=tmp;
    15         this.ifConvergence=ifCon;
    16     }
    17 }

    Data.java

     1 package OneVariableVersion;
     2 
     3 
     4 /**
     5  * 一条数据
     6  * @author XBW
     7  * @date 2014年8月17日
     8  */
     9 public class Data {
    10     Double[] x=new Double[2];
    11     Double y;
    12     
    13     public Data(Double xtmp,Double ytmp){
    14         this.x[0]=1.0;
    15         this.x[1]=xtmp;
    16         this.y=ytmp;
    17     }
    18 }

    总结:写代码的时候有几个讲究:

    1. 步长是否需要动态变化,按照coursera公开课上讲的是不必要动态改变的,因为偏导数会越来越小,但在实际情况下,按照一定的比值缩小或者自己定义一种缩小的方式可能是有必要的,所以具体情况具体分析;
    2. 初始步长的设定也是很重要的,大了就不会得到结果,因为发散了;步长越大,下降速率越快,但是也会导致震荡,所以,还是哪句话:具体问题具体分析;






                If you have any questions about this article, welcome to leave a message on the message board.



    Brad(Bowen) Xu
    E-Mail : maxxbw1992@gmail.com


  • 相关阅读:
    echarts二维坐标这样写出立体柱状图
    echarts中使图表循环显示tooltip-封装tooltip的方法轮询显示图表数据
    webpack--运行npm run dev自动打开浏览器以及热加载
    exports、module.exports和export、export default到底是咋回事,区别在哪里
    H5页面判断客户端是iOS或是Android并跳转对应链接唤起APP
    关于页面锚点跳转以及万能锚点跳转插件
    echarts Map 省份文字居中,即对应地图中心位置
    Moment.js 常见用法,常见API
    Tomcat 不加载图片验证码
    Cglib 动态代理
  • 原文地址:https://www.cnblogs.com/XBWer/p/3912792.html
Copyright © 2020-2023  润新知