• MATLAB Implementation of Univariate Linear Regression


    1. Theory
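
    Univariate linear regression fits a straight line to the training data. The hypothesis, cost function, and gradient-descent updates implemented by the code in Section 3 are:

    h_\theta(x) = \theta_0 + \theta_1 x

    J(\theta_0, \theta_1) = \frac{1}{2m} \sum_{i=1}^{m} \left( h_\theta(x^{(i)}) - y^{(i)} \right)^2

    \theta_0 := \theta_0 - \frac{\alpha}{m} \sum_{i=1}^{m} \left( h_\theta(x^{(i)}) - y^{(i)} \right)

    \theta_1 := \theta_1 - \frac{\alpha}{m} \sum_{i=1}^{m} \left( h_\theta(x^{(i)}) - y^{(i)} \right) x^{(i)}

    where m is the number of training examples and \alpha is the learning rate; both parameters are updated simultaneously at every iteration until convergence.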

    2. Dataset

    Each line below is one training example from ex1data1.txt: the first column is a city's population in units of 10,000, and the second is the profit in that city in units of $10,000.

    6.1101,17.592
    5.5277,9.1302
    8.5186,13.662
    7.0032,11.854
    5.8598,6.8233
    8.3829,11.886
    7.4764,4.3483
    8.5781,12
    6.4862,6.5987
    5.0546,3.8166
    5.7107,3.2522
    14.164,15.505
    5.734,3.1551
    8.4084,7.2258
    5.6407,0.71618
    5.3794,3.5129
    6.3654,5.3048
    5.1301,0.56077
    6.4296,3.6518
    7.0708,5.3893
    6.1891,3.1386
    20.27,21.767
    5.4901,4.263
    6.3261,5.1875
    5.5649,3.0825
    18.945,22.638
    12.828,13.501
    10.957,7.0467
    13.176,14.692
    22.203,24.147
    5.2524,-1.22
    6.5894,5.9966
    9.2482,12.134
    5.8918,1.8495
    8.2111,6.5426
    7.9334,4.5623
    8.0959,4.1164
    5.6063,3.3928
    12.836,10.117
    6.3534,5.4974
    5.4069,0.55657
    6.8825,3.9115
    11.708,5.3854
    5.7737,2.4406
    7.8247,6.7318
    7.0931,1.0463
    5.0702,5.1337
    5.8014,1.844
    11.7,8.0043
    5.5416,1.0179
    7.5402,6.7504
    5.3077,1.8396
    7.4239,4.2885
    7.6031,4.9981
    6.3328,1.4233
    6.3589,-1.4211
    6.2742,2.4756
    5.6397,4.6042
    9.3102,3.9624
    9.4536,5.4141
    8.8254,5.1694
    5.1793,-0.74279
    21.279,17.929
    14.908,12.054
    18.959,17.054
    7.2182,4.8852
    8.2951,5.7442
    10.236,7.7754
    5.4994,1.0173
    20.341,20.992
    10.136,6.6799
    7.3345,4.0259
    6.0062,1.2784
    7.2259,3.3411
    5.0269,-2.6807
    6.5479,0.29678
    7.5386,3.8845
    5.0365,5.7014
    10.274,6.7526
    5.1077,2.0576
    5.7292,0.47953
    5.1884,0.20421
    6.3557,0.67861
    9.7687,7.5435
    6.5159,5.3436
    8.5172,4.2415
    9.1802,6.7981
    6.002,0.92695
    5.5204,0.152
    5.0594,2.8214
    5.7077,1.8451
    7.6366,4.2959
    5.8707,7.2029
    5.3054,1.9869
    8.2934,0.14454
    13.394,9.0551
    5.4369,0.61705

    3. Code Implementation

    clear all;
    clc;
    data = load('ex1data1.txt');
    X = data(:, 1); y = data(:, 2);
    m = length(y); % number of training examples
    plot(X,y,'rx');
    
    %% =================== Gradient descent ===================
    fprintf('Running Gradient Descent ...\n')
    
    % Why add a column of ones: so that theta0 is multiplied by 1 (the intercept term) when computing J
    X = [ones(m, 1), data(:,1)]; % Add a column of ones to x
    theta = zeros(2, 1); % initialize fitting parameters
    
    % Some gradient descent settings
    iterations = 1500;
    alpha = 0.01;
    
    % compute and display initial cost
    computeCost(X, y, theta)
    
    % run gradient descent
    [theta, J_history]= gradientDescent(X, y, theta, alpha, iterations);
    
    hold on; % keep previous plot visible
    plot(X(:,2), X*theta, '-')
    legend('Training data', 'Linear regression')
    hold off % don't overlay any more plots on this figure
    
    % Predict values for population sizes of 35,000 and 70,000
    predict1 = [1, 3.5] *theta;
    fprintf('For population = 35,000, we predict a profit of %f\n', ...
        predict1*10000);
    predict2 = [1, 7] * theta;
    fprintf('For population = 70,000, we predict a profit of %f\n', ...
        predict2*10000);
    
    % Grid over which we will calculate J
    theta0_vals = linspace(-10, 10, 100);
    theta1_vals = linspace(-1, 4, 100);
    
    % initialize J_vals to a matrix of 0's
    J_vals = zeros(length(theta0_vals), length(theta1_vals));
    
    % Fill out J_vals
    for i = 1:length(theta0_vals)
        for j = 1:length(theta1_vals)
            t = [theta0_vals(i); theta1_vals(j)];
            J_vals(i,j) = computeCost(X, y, t);
        end
    end
    
    
    % Because of the way meshgrids work in the surf command, we need to 
    % transpose J_vals before calling surf, or else the axes will be flipped
    J_vals = J_vals';
    % Surface plot
    figure;
    surf(theta0_vals, theta1_vals, J_vals)
    xlabel('\theta_0'); ylabel('\theta_1');
    
    % Contour plot
    figure;
    % Plot J_vals as 20 contour levels spaced logarithmically between 0.01 and 1000
    % logspace(-2, 3, 20) returns 20 values from 10^-2 to 10^3, evenly spaced in the exponent
    contour(theta0_vals, theta1_vals, J_vals, logspace(-2, 3, 20))
    xlabel('\theta_0'); ylabel('\theta_1');
    hold on;
    plot(theta(1), theta(2), 'rx', 'MarkerSize', 10, 'LineWidth', 2);
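
    The returned J_history is not used above; a quick convergence check can be appended to the script (an optional sketch, not part of the original code):

    % Plot the cost recorded at each gradient-descent iteration
    figure;
    plot(1:iterations, J_history, '-b', 'LineWidth', 1.5);
    xlabel('Iteration'); ylabel('Cost J');
    title('Gradient descent convergence');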
    

      The cost function, saved in its own file computeCost.m:

    function J = computeCost(X, y, theta)
    
    m = length(y); % number of training examples 
    J = 0;
    
    for i=1:m
        J = J +(theta(1)*X(i,1) + theta(2)*X(i,2) - y(i))^2;  
    end
    % Divide by 2m: the factor 2 cancels the 2 produced when taking the
    % partial derivative of the squared term, and m averages over the
    % training examples so J does not grow with the dataset size
    J = J/(2*m);
    end
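
    For reference, the loop above can be collapsed into a single matrix expression. A minimal vectorized sketch (computeCostVec is a name chosen here, not part of the original script):

    function J = computeCostVec(X, y, theta)
    % Vectorized cost: sum of squared errors over m examples, divided by 2m
    m = length(y);
    J = sum((X * theta - y) .^ 2) / (2 * m);
    end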
    

      The gradient-descent routine, saved in its own file gradientDescent.m:

    function [theta, J_history] = gradientDescent(X, y, theta, alpha, num_iters)
    
    m = length(y); % number of training examples
    J_history = zeros(num_iters, 1);
    for iter = 1:num_iters
        % Partial derivatives of J w.r.t. theta(1) and theta(2);
        % they must be reset to zero at the start of every iteration
        J_1 = 0;
        J_2 = 0;
        for i = 1:m
            J_1 = J_1 + theta(1)*X(i,1) + theta(2)*X(i,2) - y(i);
            J_2 = J_2 + (theta(1)*X(i,1) + theta(2)*X(i,2) - y(i)) * X(i,2);
        end
        % Divide by m once here rather than inside the inner loop
        J_1 = J_1/m;
        J_2 = J_2/m;
        % A temp-variable version (temp1/temp2) is unnecessary here: J_1 and
        % J_2 were fully computed before either assignment, so these two
        % updates are already simultaneous
        theta(1) = theta(1) - alpha * J_1;
        theta(2) = theta(2) - alpha * J_2;
        J_history(iter) = computeCost(X, y, theta);  
    %     save J_history J_history
    end
    end
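
    The per-example loop can likewise be replaced with one matrix update per iteration. A sketch under the same conventions (gradientDescentVec is a hypothetical name):

    function [theta, J_history] = gradientDescentVec(X, y, theta, alpha, num_iters)
    % Vectorized gradient descent: both partial derivatives in one expression
    m = length(y);
    J_history = zeros(num_iters, 1);
    for iter = 1:num_iters
        theta = theta - (alpha / m) * (X' * (X * theta - y));
        J_history(iter) = computeCost(X, y, theta);
    end
    end

    As a sanity check, MATLAB's backslash operator solves the same least-squares problem in closed form, so theta = X \ y should agree closely with the gradient-descent result.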
    

    4. Results

    Running the script prints the initial cost and the two profit predictions, then produces three figures: the training-data scatter plot with the fitted regression line, a surface plot of the cost J(\theta_0, \theta_1), and a contour plot of J with the gradient-descent solution marked by a red cross.

• Original article: https://www.cnblogs.com/hxsyl/p/4884607.html