• 线性拟合-实验报告


    一.实验方法:

         1最小二乘法

         2梯度下降法

    二.公式推导

     

    最小二乘

     

    用线性函数h ax=a0+a1*x来拟合y=fx);

     

    构造代价函数Ja):

     

     

     代价函数分别对a0a1求偏导,连个偏导数都等于0成为两个方程,两个方程联合求解得到a0a1

     

    梯度下降

     

     构造代价函数Ja),Ja)对a0a1分别求偏导得到梯度,

     

    J(a)/a0=n*a0+a1*sumx-sumy;

     

    J(a)/a1=a1*sumx*sumx+a0*sumx-sumx*sumy;

     

        tidu_a0=n*a0+a1*sumx-sumy;

     

        tidu_a1=a1*sumxx+a0*sumx-sumxy;

     

    设置步长为l,迭代m

     

        delta_r=sqrt(tidu_a0*tidu_a0+tidu_a1*tidu_a1);

     

        a0=a0-l*(tidu_a0/tidu_r);

     

        a1=a1-l*(tidu_a1/tidu_r);

     

    每次迭代显示得到的直线和mse,并修订学习率

     

    %显示直线

     

        x2=[-0.1,1.1];

     

        y2=x2.*a1+a0;

     

        plot(x2,y2,'color',[1-i/m,1-i/m,1-i/m]);

     

       %显示错误

     

        error=0;

     

        for j=1:n

     

            error=error+(y(j)-(a1*x(j)+a0))*(y(j)-(a1*x(j)+a0));

     

        end

     

        mse=error/n;

     

        l=mse;

     

        mse

     

    三.matlab代码

     

    最小二乘法代码:

     

    %in是一个1002列的矩阵,两列分别为xy。用一条直线y=x*a+b拟合xy的关系;

     

    %用最小二乘法计算ab

     

    x=in(1:100,1);y=in(1:100,2);sumx=0;sumy=0;sumxx=0;sumyy=0;sumxy=0;for i=1:1:100 sumx=sumx+x(i); sumy=sumy+y(i); sumxx=sumxx+x(i)*x(i); sumyy=sumyy+y(i)*y(i); sumxy=sumxy+x(i)*y(i);endplot(in(:,1),in(:,2),'r.'); %用红色的点画出100个样本点hold on; %保留当前绘图,不被下次绘图遮盖n=100;[b,a]=solve('n*a0+a1*sumx=sumy','a0*sumx+a1*sumxx=sumxy','a0','a1'); 

     

    %解二元一次方程组,未知数为a0a1,结果返回给baa=eval(a); 

     

    %evalstr),把str当做一条语句执行b=eval(b);x2=[0,1]; 

     

    %知道解析式y=a*x+b,画直线的方法y2=x2.*a+b; 

     

    因为x2是一个向量,所以用x2.表示plot(x2,y2); 

     

    %制动化一条以x2x,以y2y的直线 mse=0;error=0;for i=1:n error=error+(y(i)-(a*x(i)+b))*(y(i)-(a*x(i)+b));endmse=error/n;mse

     

    梯度下降法代码:

     

    x=in(1:100,1);

     

    y=in(1:100,2);

     

    sumx=0;

     

    sumy=0;

     

    sumxx=0;

     

    sumyy=0;

     

    sumxy=0;

     

    for i=1:1:100

     

        sumx=sumx+x(i);

     

        sumy=sumy+y(i);

     

        sumxx=sumxx+x(i)*x(i);

     

        sumyy=sumyy+y(i)*y(i);

     

        sumxy=sumxy+x(i)*y(i);

     

    end

     

    plot(in(:,1),in(:,2),'r.');    

     

    hold on;

     

    a0=2;

     

    a1=1;

     

    l=0.5;

     

    n=100;

     

    m=50;

     

    for i=0:1:m

     

        tidu_a0=n*a0+a1*sumx-sumy;

     

        tidu_a1=a1*sumxx+a0*sumx-sumxy;

     

        tidu_r=sqrt(tidu_a0*tidu_a0+tidu_a1*tidu_a1);

     

        a0=a0-l*(tidu_a0/tidu_r);

     

        a1=a1-l*(tidu_a1/tidu_r);

     

        

     

        x2=[-0.1,1.1];

     

        y2=x2.*a1+a0;

     

        plot(x2,y2,'color',[1-i/m,1-i/m,1-i/m]);

     

       

     

        error=0;

     

        for j=1:n

     

            error=error+(y(j)-(a1*x(j)+a0))*(y(j)-(a1*x(j)+a0));

     

        end

     

        mse=error/n;

     

        l=mse;

     

        mse

     

    end

     

     

     

    四.运行结果

     

    最小二乘法结果

     

     

    2梯度下降法结果

     

     

    五.误差

     

    最小二乘法

     

    A1=3.679365985769617

     

    A0=--1.030876273676726

     

    均方误差mse=0.0429

     

    2梯度下降法

     

    (起点为a0=2a1=1;迭代次数为50

     

    A1=3.67860477725630

     

    A0=-1.00565713447357

     

    均方误差mse =0.0436

     

    使用mesh查看损失函数与a0,a1的关系:

    mesh显示mse的代码:

    x=in(1:100,1);
    y=in(1:100,2);
    sumx=0;
    sumy=0;
    sumxx=0;
    sumyy=0;
    sumxy=0;
    for i=1:1:100
        sumx=sumx+x(i);
        sumy=sumy+y(i);
        sumxx=sumxx+x(i)*x(i);
        sumyy=sumyy+y(i)*y(i);
        sumxy=sumxy+x(i)*y(i);
    end
    n=100;
    a0=-4.5:0.1:2.5;
    a1=0:0.1:7;
    [A0,A1]=meshgrid(a0,a1);
    Z=(sumyy+A0.*A0.*n+A1.*A1.*sumxx+A0.*A1.*2*sumx-A0.*2*sumy-A1.*2*sumxy)/n;
    mesh(A0,A1,Z);
    xlabel('a0轴');
    ylabel('a1轴');
    zlabel('mse轴');
    title('梯度下降');

     

     

    六.附数据   in.txt

     

    0.9005 1.9113

     

    0.4480 0.9218

     

    0.2689 -0.4654

     

    0.5538 1.4667

     

    0.1788 -0.2393

     

    0.8597 1.7048

     

    0.2320 -0.2135

     

    0.1681 -0.2549

     

    0.0267 -1.0928

     

    0.3224 0.2985

     

    0.5552 0.7931

     

    0.8245 2.0172

     

    0.8042 2.2273

     

    0.0244 -0.8888

     

    0.3715 0.5687

     

    0.4919 0.7795

     

    0.4661 0.5348

     

    0.0417 -0.7969

     

    0.6170 1.2403

     

    0.5780 1.5113

     

    0.2988 -0.1120

     

    0.4357 0.5782

     

    0.1366 -0.8407

     

    0.2997 0.3807

     

    0.7614 1.8959

     

    0.0353 -0.6399

     

    0.2695 -0.1072

     

    0.9963 2.7233

     

    0.4469 0.8604

     

    0.1528 -0.5472

     

    0.8862 2.3398

     

    0.0314 -1.2190

     

    0.1160 -0.6832

     

    0.2509 -0.1495

     

    0.7597 1.6176

     

    0.8983 1.9552

     

    0.2234 -0.1696

     

    0.6733 1.4859

     

    0.8188 2.1008

     

    0.9489 2.6517

     

    0.8743 2.0069

     

    0.3937 0.4557

     

    0.9370 2.4427

     

    0.4369 0.8025

     

    0.1625 -0.2676

     

    0.3098 -0.0641

     

    0.6811 1.1038

     

    0.9341 2.2406

     

    0.9474 2.6501

     

    0.5991 1.1617

     

    0.9489 2.4170

     

    0.4040 0.3019

     

    0.0410 -1.0271

     

    0.2938 0.1261

     

    0.0319 -0.7842

     

    0.8645 2.2468

     

    0.4325 0.5829

     

    0.0928 -0.4767

     

    0.1378 -0.5801

     

    0.2420 -0.1617

     

    0.2230 -0.4245

     

    0.8677 2.1976

     

    0.7642 1.7447

     

    0.3447 0.0178

     

    0.3848 0.4811

     

    0.5949 1.2016

     

    0.5351 1.3388

     

    0.3336 0.2838

     

    0.8547 2.2127

     

    0.2656 -0.1061

     

    0.9339 2.1840

     

    0.3898 0.1515

     

    0.6831 1.5417

     

    0.2750 0.2706

     

    0.0280 -0.8750

     

    0.9406 2.6179

     

    0.5340 0.8242

     

    0.6712 1.4927

     

    0.6075 1.1417

     

    0.7509 1.5665

     

    0.9813 2.7267

     

    0.7277 1.5830

     

    0.8573 1.4756

     

    0.9918 3.0038

     

    0.7595 1.6970

     

    0.1460 -0.4369

     

    0.3263 0.0628

     

    0.0288 -0.9162

     

    0.6946 1.4643

     

    0.9588 2.4821

     

    0.7290 1.5572

     

    0.7368 1.4520

     

    0.1746 -0.4995

     

    0.3554 0.3202

     

    0.5746 1.0338

     

    0.4599 0.9678

     

    0.8337 2.6507

     

    0.8154 1.8128

     

    0.3240 -0.0295

     

    0.4617 0.4441

     

  • 相关阅读:
    Unzip 解压报错
    Linux ftp安装
    关于vsftp出现Restarting vsftpd (via systemctl): Job for vsftpd.service failed because the control 的解决办法
    ASP.NET开发知识总结
    移动端开发调试方法总结
    移动H5优化指南
    基于windows下,node.js之npm
    微服务理解
    SQL Server 触发器
    jQuery验证控件jquery.validate.js使用说明+中文API
  • 原文地址:https://www.cnblogs.com/huangzq/p/3432553.html
Copyright © 2020-2023  润新知