• 线性回归——最小二乘法拟合


    记录下用最小二乘法拟合线性模型的代码实现:

    1、应用正规方程(Normal Equation)求解最小二乘法举例:

    最简单的y关于x的线性方程(y=eta_0+eta_1x)

    预测值和观察值:

     写成矩阵形式:

    (Xeta=y),其中(X=egin{bmatrix} 1& x_1\ 1& x_2\ ...&...\ 1& x_n\ end{bmatrix}) , (eta=egin{bmatrix} eta_0\ eta_1\ end{bmatrix}) , (y=egin{bmatrix} y_1\ y_2\ ...\ y_n\ end{bmatrix})

    最小二乘法的解(eta)可以通过解正规方程获得: (X^TXeta=X^Ty)

    python代码实现如下:

    import numpy as np
    from numpy.linalg import inv
    
    
    x=[0.015348106072082663,0.021715879765738008,0.0316253067336889,
       0.03212431406271639,0.03828189026158509,0.03965578961513808,
       0.05502733389871053,0.06116957740576664,0.06170785924013203,
       0.07206404835977503]
    y=[22, 0, 22, 11, 9, 31, 20, 31, 2, 20]
    
    
    def plotLinearRegression1():
        X = np.array(x).reshape(len(x),1)
        reg=linear_model.LinearRegression(fit_intercept=True,normalize=False)
        reg.fit(X,y)
        k=reg.coef_#获取斜率w1,w2,w3,...,wn
        b=reg.intercept_#获取截距w0
        #x0=np.arange(0,10,0.2)
        x0 = np.array(x).reshape(len(x),)
        y0=k*x0+b
        plt.scatter(x,y)
        plt.plot(x0, y0)
        print('k',k);
        print('b',b)
      
          
    '''
    用least squares
    '''
    def plotLinearRegression2():
        X = np.vstack([np.ones((1,len(x))),x]).T
        print('X:
    ',X)
        Xt = X.T
        Z = Xt.dot(X)
        invZ = inv(Z)
        print('invZ:
    ',invZ)
        
        W = (invZ.dot(Xt)).dot(y)
        print('W:
    ',W)
        
        #plot
        plt.scatter(x,y)
        x0 = np.array(x).reshape(len(x),)
       # x0=np.arange(0,10,0.2)
        plt.plot(x0, W[1]*x0+W[0])
        
        
    def plotLinearRegression3():
        A = np.vstack([x, np.ones(len(x))]).T
        print('A:
    ',A)
        m, c = np.linalg.lstsq(A, y, rcond=None)[0]
        print('m: ',m,'  c:',c)
       
        x0 = np.array(x).reshape(len(x),)
        plt.plot(x, y, 'o', label='Original data', markersize=10)
        plt.plot(x0, m*x0 + c, 'r', label='Fitted line')
        plt.legend()
        plt.show()  
        
            
    if __name__ == '__main__':
        plotLinearRegression2()

     

  • 相关阅读:
    SQLSERVER 的表分区(水平) 操作记录2
    GraphQl in ASP.NET Core
    初始认知学习 .net core 逐步加深
    C# 关于使用JavaScriptSerializer 序列化与返序列化的操作
    Nginx、IIS 相关命令
    SqlServer:查询指定表所有外键关联表信息
    centos 重启宝塔命令
    c# 根据日志中的方法信息,反射再次执行相关方法
    jackson 下载地址记录
    【设计模式】六大原则
  • 原文地址:https://www.cnblogs.com/davidxu/p/14394963.html
Copyright © 2020-2023  润新知