• sklearn线性回归实现房价预测模型


    题目要求

    建立房价预测模型:利用ex1data1.txt单特征)和ex1data2.txt多特征)中的数据,进行线性回归和预测。

    作散点图可知,数据大致符合线性关系,故暂不研究其他形式的回归。

    两份数据放在最后。

    单特征线性回归

    ex1data1.txt中的数据是单特征,作一个简单的线性回归即可:(y=ax+b)

    根据是否分割数据,产生两种方案:方案一,所有样本都用来训练和预测;方案二,一部分样本用来训练,一部分用来检验模型。

    方案一

    对ex1data1.txt中的数据进行线性回归,所有样本都用来训练和预测。

    代码实现如下:

    """
        对ex1data1.txt中的数据进行线性回归,所有样本都用来训练和预测
    """
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.linear_model import LinearRegression
    from sklearn.metrics import mean_squared_error, r2_score
    plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
    plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
    
    # 数据格式:城市人口,食品经销商利润
    
    # 读取数据
    data = np.loadtxt('ex1data1.txt', delimiter=',')
    data_X = data[:, 0]
    data_y = data[:, 1]
    
    # 训练模型
    model = LinearRegression()
    model.fit(data_X.reshape([-1, 1]), data_y)
    
    # 利用模型进行预测
    y_predict = model.predict(data_X.reshape([-1, 1]))
    
    # 结果可视化
    plt.scatter(data_X, data_y, color='red')
    plt.plot(data_X, y_predict, color='blue', linewidth=3)
    plt.xlabel('城市人口')
    plt.ylabel('食品经销商利润')
    plt.title('线性回归——城市人口与食品经销商利润的关系')
    plt.show()
    
    # 模型参数
    print(model.coef_)
    print(model.intercept_)
    # MSE
    print(mean_squared_error(data_y, y_predict))
    # R^2
    print(r2_score(data_y, y_predict))
    
    

    结果如下:

    由下可知函数形式以及(R^2)为0.70

    [1.19303364]
    -3.89578087831185
    8.953942751950358
    0.7020315537841397
    

    ex1data1_1.png

    方案二

    对ex1data1.txt中的数据进行线性回归,部分样本用来训练,部分样本用来预测。

    实现如下:

    """
        对ex1data1.txt中的数据进行线性回归,部分样本用来训练,部分样本用来预测
    """
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.linear_model import LinearRegression
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import mean_squared_error, r2_score
    plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
    plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
    
    # 数据格式:城市人口,食品经销商利润
    
    # 读取数据
    data = np.loadtxt('ex1data1.txt', delimiter=',')
    data_X = data[:, 0]
    data_y = data[:, 1]
    
    # 数据分割
    X_train, X_test, y_train, y_test = train_test_split(data_X, data_y)
    
    # 训练模型
    model = LinearRegression()
    model.fit(X_train.reshape([-1, 1]), y_train)
    
    # 利用模型进行预测
    y_predict = model.predict(X_test.reshape([-1, 1]))
    
    # 结果可视化
    plt.scatter(X_test, y_test, color='red')  # 测试样本
    plt.plot(X_test, y_predict, color='blue', linewidth=3)
    plt.xlabel('城市人口')
    plt.ylabel('食品经销商利润')
    plt.title('线性回归——城市人口与食品经销商利润的关系')
    plt.show()
    
    # 模型参数
    print(model.coef_)
    print(model.intercept_)
    # MSE
    print(mean_squared_error(y_test, y_predict))
    # R^2
    print(r2_score(y_test, y_predict))
    
    

    结果如下

    由下可知函数形式以及(R^2)为0.80

    [1.21063939]
    -4.195481965945055
    5.994362667047617
    0.8095125123727652
    

    ex1data1_2.png

    多特征线性回归

    ex1data2.txt中的数据是二个特征,作一个最简单的多元(在此为二元)线性回归即可:(y=a_1x_1+a_2x_2+b)

    对ex1data2.txt中的数据进行线性回归,所有样本都用来训练和预测。

    代码实现如下:

    """
        对ex1data2.txt中的数据进行线性回归,所有样本都用来训练和预测
    """
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.linear_model import LinearRegression
    from mpl_toolkits.mplot3d import Axes3D  # 不要去掉这个import
    from sklearn.metrics import mean_squared_error, r2_score
    plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
    plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
    
    # 数据格式:城市人口,房间数目,房价
    
    # 读取数据
    data = np.loadtxt('ex1data2.txt', delimiter=',')
    data_X = data[:, 0:2]
    data_y = data[:, 2]
    
    # 训练模型
    model = LinearRegression()
    model.fit(data_X, data_y)
    
    # 利用模型进行预测
    y_predict = model.predict(data_X)
    
    # 结果可视化
    fig = plt.figure()
    ax = fig.gca(projection='3d')
    ax.scatter(data_X[:, 0], data_X[:, 1], data_y, color='red')
    ax.plot(data_X[:, 0], data_X[:, 1], y_predict, color='blue')
    ax.set_xlabel('城市人口')
    ax.set_ylabel('房间数目')
    ax.set_zlabel('房价')
    plt.title('线性回归——城市人口、房间数目与房价的关系')
    plt.show()
    
    # 模型参数
    print(model.coef_)
    print(model.intercept_)
    # MSE
    print(mean_squared_error(data_y, y_predict))
    # R^2
    print(r2_score(data_y, y_predict))
    
    

    结果如下:

    由下可知函数形式以及(R^2)为0.73

    [  139.21067402 -8738.01911233]
    89597.90954279748
    4086560101.205658
    0.7329450180289141
    

    ex1data2.png

    两份数据

    ex1data1.txt

    6.1101,17.592
    5.5277,9.1302
    8.5186,13.662
    7.0032,11.854
    5.8598,6.8233
    8.3829,11.886
    7.4764,4.3483
    8.5781,12
    6.4862,6.5987
    5.0546,3.8166
    5.7107,3.2522
    14.164,15.505
    5.734,3.1551
    8.4084,7.2258
    5.6407,0.71618
    5.3794,3.5129
    6.3654,5.3048
    5.1301,0.56077
    6.4296,3.6518
    7.0708,5.3893
    6.1891,3.1386
    20.27,21.767
    5.4901,4.263
    6.3261,5.1875
    5.5649,3.0825
    18.945,22.638
    12.828,13.501
    10.957,7.0467
    13.176,14.692
    22.203,24.147
    5.2524,-1.22
    6.5894,5.9966
    9.2482,12.134
    5.8918,1.8495
    8.2111,6.5426
    7.9334,4.5623
    8.0959,4.1164
    5.6063,3.3928
    12.836,10.117
    6.3534,5.4974
    5.4069,0.55657
    6.8825,3.9115
    11.708,5.3854
    5.7737,2.4406
    7.8247,6.7318
    7.0931,1.0463
    5.0702,5.1337
    5.8014,1.844
    11.7,8.0043
    5.5416,1.0179
    7.5402,6.7504
    5.3077,1.8396
    7.4239,4.2885
    7.6031,4.9981
    6.3328,1.4233
    6.3589,-1.4211
    6.2742,2.4756
    5.6397,4.6042
    9.3102,3.9624
    9.4536,5.4141
    8.8254,5.1694
    5.1793,-0.74279
    21.279,17.929
    14.908,12.054
    18.959,17.054
    7.2182,4.8852
    8.2951,5.7442
    10.236,7.7754
    5.4994,1.0173
    20.341,20.992
    10.136,6.6799
    7.3345,4.0259
    6.0062,1.2784
    7.2259,3.3411
    5.0269,-2.6807
    6.5479,0.29678
    7.5386,3.8845
    5.0365,5.7014
    10.274,6.7526
    5.1077,2.0576
    5.7292,0.47953
    5.1884,0.20421
    6.3557,0.67861
    9.7687,7.5435
    6.5159,5.3436
    8.5172,4.2415
    9.1802,6.7981
    6.002,0.92695
    5.5204,0.152
    5.0594,2.8214
    5.7077,1.8451
    7.6366,4.2959
    5.8707,7.2029
    5.3054,1.9869
    8.2934,0.14454
    13.394,9.0551
    5.4369,0.61705
    

    ex1data2.txt

    2104,3,399900
    1600,3,329900
    2400,3,369000
    1416,2,232000
    3000,4,539900
    1985,4,299900
    1534,3,314900
    1427,3,198999
    1380,3,212000
    1494,3,242500
    1940,4,239999
    2000,3,347000
    1890,3,329999
    4478,5,699900
    1268,3,259900
    2300,4,449900
    1320,2,299900
    1236,3,199900
    2609,4,499998
    3031,4,599000
    1767,3,252900
    1888,2,255000
    1604,3,242900
    1962,4,259900
    3890,3,573900
    1100,3,249900
    1458,3,464500
    2526,3,469000
    2200,3,475000
    2637,3,299900
    1839,2,349900
    1000,1,169900
    2040,4,314900
    3137,3,579900
    1811,4,285900
    1437,3,249900
    1239,3,229900
    2132,4,345000
    4215,4,549000
    2162,4,287000
    1664,2,368500
    2238,3,329900
    2567,4,314000
    1200,3,299000
    852,2,179900
    1852,4,299900
    1203,3,239500
    

    作者:@臭咸鱼

    转载请注明出处:https://www.cnblogs.com/chouxianyu/

    欢迎讨论和交流!


  • 相关阅读:
    RC4加密
    树莓派3B+学习笔记:13、不间断会话服务screen
    树莓派3B+学习笔记:12、安装FireFox浏览器
    树莓派3B+学习笔记:11、查看硬件信息
    树莓派3B+学习笔记:10、使用SSH连接树莓派
    树莓派3B+学习笔记:9、更改软件源
    树莓派3B+学习笔记:8、安装MySQL
    树莓派3B+学习笔记:7、挂载exfat格式U盘和NTFS格式移动硬盘
    树莓派3B+学习笔记:6、安装TeamViewer
    树莓派3B+学习笔记:5、安装vim
  • 原文地址:https://www.cnblogs.com/chouxianyu/p/11704665.html
Copyright © 2020-2023  润新知