• sklearn Logistic Regression in Practice


    Problem Statement

    Given students' scores on two exams and whether each student was admitted, predict whether a student will be admitted: perform logistic regression and prediction on the data in ex2data1.txt and ex2data2.txt.

    The data files are included at the end of this post.
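
    For background, logistic regression models the admission probability by applying the sigmoid function to a linear combination of the features, and predicts the positive class when that probability reaches 0.5 (the standard model, restated here only as a reminder):

    $$
    h_\theta(x) = \sigma(\theta^{T} x) = \frac{1}{1 + e^{-\theta^{T} x}},
    \qquad
    \hat{y} =
    \begin{cases}
    1 & \text{if } h_\theta(x) \ge 0.5 \\
    0 & \text{otherwise}
    \end{cases}
    $$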

    Processing ex2data1.txt

    A scatter plot of the data shows that the decision boundary is roughly linear, but with some curvature (nonlinearity), so a purely linear boundary does not fit well. Two approaches are therefore tried: Approach 1, without polynomial features; Approach 2, with polynomial features. A minimal sketch of the scatter plot follows.
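
    The sketch below assumes ex2data1.txt is in the working directory, in the format described above:

    import numpy as np
    import matplotlib.pyplot as plt

    # Columns 0 and 1 are the two exam scores, column 2 is the admission label
    data = np.loadtxt('ex2data1.txt', delimiter=',')
    admitted = data[:, 2] == 1
    plt.scatter(data[~admitted, 0], data[~admitted, 1], color='red', label='not admitted')
    plt.scatter(data[admitted, 0], data[admitted, 1], color='blue', label='admitted')
    plt.legend()
    plt.show()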

    Approach 1: No Polynomial Features

    Perform logistic regression on the data in ex2data1.txt, without polynomial features.

    The code is as follows:

    """
        对ex2data1.txt中的数据进行逻辑回归(无多项式特征)
    """
    from sklearn.model_selection import train_test_split
    from matplotlib.colors import ListedColormap
    from sklearn.linear_model import LogisticRegression
    import numpy as np
    import matplotlib.pyplot as plt
    plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
    plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
    
    # 数据格式:成绩1,成绩2,是否被录取(1代表被录取,0代表未被录取)
    
    
    # 函数(画决策边界)定义
    def plot_decision_boundary(model, axis):
        x0, x1 = np.meshgrid(
            np.linspace(axis[0], axis[1], int((axis[1] - axis[0]) * 100)).reshape(-1, 1),
            np.linspace(axis[2], axis[3], int((axis[3] - axis[2]) * 100)).reshape(-1, 1),
        )
        X_new = np.c_[x0.ravel(), x1.ravel()]
    
        y_predict = model.predict(X_new)
        zz = y_predict.reshape(x0.shape)
    
        custom_cmap = ListedColormap(['#EF9A9A', '#FFF59D', '#90CAF9'])
    
        plt.contourf(x0, x1, zz, cmap=custom_cmap)
    
    # 读取数据
    data = np.loadtxt('ex2data1.txt', delimiter=',')
    data_X = data[:, 0:2]
    data_y = data[:, 2]
    
    # 数据分割
    X_train, X_test, y_train, y_test = train_test_split(data_X, data_y, random_state=666)
    
    # 训练模型
    log_reg = LogisticRegression()
    log_reg.fit(X_train, y_train)
    
    # 结果可视化
    plot_decision_boundary(log_reg, axis=[0, 100, 0, 100])
    plt.scatter(data_X[data_y == 0, 0], data_X[data_y == 0, 1], color='red')
    plt.scatter(data_X[data_y == 1, 0], data_X[data_y == 1, 1], color='blue')
    plt.xlabel('成绩1')
    plt.ylabel('成绩2')
    plt.title('两门课程成绩与是否录取的关系')
    plt.show()
    
    # 模型测试
    print(log_reg.score(X_train, y_train))
    print(log_reg.score(X_test, y_test))
    
    

    The output is as follows:

    0.8533333333333334
    0.76
    

    [Figure: ex2data1逻辑回归(无多项式).png — decision boundary without polynomial features]
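
    The trained model can also score an individual applicant. A minimal sketch, appended to the script above (the two exam scores here are made up for illustration):

    # Hypothetical applicant scoring 75 and 60 on the two exams
    new_student = np.array([[75.0, 60.0]])
    print(log_reg.predict(new_student))        # predicted class: 1 = admitted, 0 = not admitted
    print(log_reg.predict_proba(new_student))  # probabilities for [not admitted, admitted]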

    Approach 2: Add Polynomial Features

    Perform logistic regression on the data in ex2data1.txt, this time adding polynomial features. In my tests, degree=3 took a long time to train, while degree=2 trained in an acceptable time and performed much better than Approach 1 (a rough timing sketch is given after the output below).

    The implementation is as follows:

    """
        对ex2data1.txt中的数据进行逻辑回归(引入多项式特征)
    """
    from sklearn.model_selection import train_test_split
    from matplotlib.colors import ListedColormap
    from sklearn.linear_model import LogisticRegression
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.preprocessing import PolynomialFeatures
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler
    plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
    plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
    
    # 数据格式:成绩1,成绩2,是否被录取(1代表被录取,0代表未被录取)
    
    
    # 函数定义
    def plot_decision_boundary(model, axis):
        x0, x1 = np.meshgrid(
            np.linspace(axis[0], axis[1], int((axis[1] - axis[0]) * 100)).reshape(-1, 1),
            np.linspace(axis[2], axis[3], int((axis[3] - axis[2]) * 100)).reshape(-1, 1),
        )
        X_new = np.c_[x0.ravel(), x1.ravel()]
    
        y_predict = model.predict(X_new)
        zz = y_predict.reshape(x0.shape)
    
        custom_cmap = ListedColormap(['#EF9A9A', '#FFF59D', '#90CAF9'])
    
        plt.contourf(x0, x1, zz, cmap=custom_cmap)
    
    
    def PolynomialLogisticRegression(degree):
        return Pipeline([
            ('poly', PolynomialFeatures(degree=degree)),
            ('std_scaler', StandardScaler()),
            ('log_reg', LogisticRegression())
        ])
    
    
    # 读取数据
    data = np.loadtxt('ex2data1.txt', delimiter=',')
    data_X = data[:, 0:2]
    data_y = data[:, 2]
    
    # 数据分割
    X_train, X_test, y_train, y_test = train_test_split(data_X, data_y, random_state=666)
    
    # 训练模型
    poly_log_reg = PolynomialLogisticRegression(degree=2)
    poly_log_reg.fit(X_train, y_train)
    
    # 结果可视化
    plot_decision_boundary(poly_log_reg, axis=[0, 100, 0, 100])
    plt.scatter(data_X[data_y == 0, 0], data_X[data_y == 0, 1], color='red')
    plt.scatter(data_X[data_y == 1, 0], data_X[data_y == 1, 1], color='blue')
    plt.xlabel('成绩1')
    plt.ylabel('成绩2')
    plt.title('两门课程成绩与是否录取的关系')
    plt.show()
    
    # 模型测试
    print(poly_log_reg.score(X_train, y_train))
    print(poly_log_reg.score(X_test, y_test))
    

    The output is as follows:

    0.92
    0.92
    

    [Figure: ex2data1逻辑回归(有多项式).png — decision boundary with polynomial features (degree=2)]
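
    To check the degree trade-off mentioned above, a rough comparison can be appended to the script (a sketch that reuses PolynomialLogisticRegression, X_train, and the other names defined above; exact timings depend on the machine, and higher degrees may trigger convergence warnings):

    import time

    for d in [1, 2, 3]:
        model = PolynomialLogisticRegression(degree=d)
        start = time.time()
        model.fit(X_train, y_train)
        print('degree=%d  fit time=%.2fs  train=%.3f  test=%.3f' % (
            d, time.time() - start,
            model.score(X_train, y_train),
            model.score(X_test, y_test)))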

    Processing ex2data2.txt

    A scatter plot shows that the decision boundary for this dataset is unmistakably nonlinear, so polynomial features are introduced from the start for logistic regression on ex2data2.txt.

    The code is as follows:

    """
        对ex2data2.txt中的数据进行逻辑回归(引入多项式特征)
    """
    from sklearn.model_selection import train_test_split
    from matplotlib.colors import ListedColormap
    from sklearn.linear_model import LogisticRegression
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.preprocessing import PolynomialFeatures
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler
    plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
    plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
    
    # 数据格式:成绩1,成绩2,是否被录取(1代表被录取,0代表未被录取)
    
    
    # 函数定义
    def plot_decision_boundary(model, axis):
        x0, x1 = np.meshgrid(
            np.linspace(axis[0], axis[1], int((axis[1] - axis[0]) * 100)).reshape(-1, 1),
            np.linspace(axis[2], axis[3], int((axis[3] - axis[2]) * 100)).reshape(-1, 1),
        )
        X_new = np.c_[x0.ravel(), x1.ravel()]
    
        y_predict = model.predict(X_new)
        zz = y_predict.reshape(x0.shape)
    
        custom_cmap = ListedColormap(['#EF9A9A', '#FFF59D', '#90CAF9'])
    
        plt.contourf(x0, x1, zz, cmap=custom_cmap)
    
    
    def PolynomialLogisticRegression(degree):
        return Pipeline([
            ('poly', PolynomialFeatures(degree=degree)),
            ('std_scaler', StandardScaler()),
            ('log_reg', LogisticRegression())
        ])
    
    
    # 读取数据
    data = np.loadtxt('ex2data2.txt', delimiter=',')
    data_X = data[:, 0:2]
    data_y = data[:, 2]
    
    # 数据分割
    X_train, X_test, y_train, y_test = train_test_split(data_X, data_y, random_state=666)
    
    # 训练模型
    poly_log_reg = PolynomialLogisticRegression(degree=2)
    poly_log_reg.fit(X_train, y_train)
    
    # 结果可视化
    plot_decision_boundary(poly_log_reg, axis=[-1, 1, -1, 1])
    plt.scatter(data_X[data_y == 0, 0], data_X[data_y == 0, 1], color='red')
    plt.scatter(data_X[data_y == 1, 0], data_X[data_y == 1, 1], color='blue')
    plt.xlabel('成绩1')
    plt.ylabel('成绩2')
    plt.title('两门课程成绩与是否录取的关系')
    plt.show()
    
    # 模型测试
    print(poly_log_reg.score(X_train, y_train))
    print(poly_log_reg.score(X_test, y_test))
    
    

    The output is as follows:

    As the figure below shows, the classification result is fairly good.

    0.7954545454545454
    0.9
    

    [Figure: ex2data2逻辑回归(有多项式).png — decision boundary with polynomial features (degree=2)]
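
    The training accuracy of about 0.80 hints that degree=2 may still underfit this boundary. As a possible refinement beyond the original experiment, one could also expose the inverse regularization strength C of LogisticRegression (smaller C means stronger L2 regularization) and try higher degrees; a sketch reusing the imports and data split from the script above, with a hypothetical helper name:

    # Hypothetical variant of the pipeline that also takes C
    def PolynomialLogisticRegressionC(degree, C):
        return Pipeline([
            ('poly', PolynomialFeatures(degree=degree)),
            ('std_scaler', StandardScaler()),
            ('log_reg', LogisticRegression(C=C))
        ])

    for degree, C in [(2, 1.0), (4, 1.0), (4, 0.1)]:
        model = PolynomialLogisticRegressionC(degree, C)
        model.fit(X_train, y_train)
        print(degree, C, model.score(X_train, y_train), model.score(X_test, y_test))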

    The Two Datasets

    ex2data1.txt

    34.62365962451697,78.0246928153624,0
    30.28671076822607,43.89499752400101,0
    35.84740876993872,72.90219802708364,0
    60.18259938620976,86.30855209546826,1
    79.0327360507101,75.3443764369103,1
    45.08327747668339,56.3163717815305,0
    61.10666453684766,96.51142588489624,1
    75.02474556738889,46.55401354116538,1
    76.09878670226257,87.42056971926803,1
    84.43281996120035,43.53339331072109,1
    95.86155507093572,38.22527805795094,0
    75.01365838958247,30.60326323428011,0
    82.30705337399482,76.48196330235604,1
    69.36458875970939,97.71869196188608,1
    39.53833914367223,76.03681085115882,0
    53.9710521485623,89.20735013750205,1
    69.07014406283025,52.74046973016765,1
    67.94685547711617,46.67857410673128,0
    70.66150955499435,92.92713789364831,1
    76.97878372747498,47.57596364975532,1
    67.37202754570876,42.83843832029179,0
    89.67677575072079,65.79936592745237,1
    50.534788289883,48.85581152764205,0
    34.21206097786789,44.20952859866288,0
    77.9240914545704,68.9723599933059,1
    62.27101367004632,69.95445795447587,1
    80.1901807509566,44.82162893218353,1
    93.114388797442,38.80067033713209,0
    61.83020602312595,50.25610789244621,0
    38.78580379679423,64.99568095539578,0
    61.379289447425,72.80788731317097,1
    85.40451939411645,57.05198397627122,1
    52.10797973193984,63.12762376881715,0
    52.04540476831827,69.43286012045222,1
    40.23689373545111,71.16774802184875,0
    54.63510555424817,52.21388588061123,0
    33.91550010906887,98.86943574220611,0
    64.17698887494485,80.90806058670817,1
    74.78925295941542,41.57341522824434,0
    34.1836400264419,75.2377203360134,0
    83.90239366249155,56.30804621605327,1
    51.54772026906181,46.85629026349976,0
    94.44336776917852,65.56892160559052,1
    82.36875375713919,40.61825515970618,0
    51.04775177128865,45.82270145776001,0
    62.22267576120188,52.06099194836679,0
    77.19303492601364,70.45820000180959,1
    97.77159928000232,86.7278223300282,1
    62.07306379667647,96.76882412413983,1
    91.56497449807442,88.69629254546599,1
    79.94481794066932,74.16311935043758,1
    99.2725269292572,60.99903099844988,1
    90.54671411399852,43.39060180650027,1
    34.52451385320009,60.39634245837173,0
    50.2864961189907,49.80453881323059,0
    49.58667721632031,59.80895099453265,0
    97.64563396007767,68.86157272420604,1
    32.57720016809309,95.59854761387875,0
    74.24869136721598,69.82457122657193,1
    71.79646205863379,78.45356224515052,1
    75.3956114656803,85.75993667331619,1
    35.28611281526193,47.02051394723416,0
    56.25381749711624,39.26147251058019,0
    30.05882244669796,49.59297386723685,0
    44.66826172480893,66.45008614558913,0
    66.56089447242954,41.09209807936973,0
    40.45755098375164,97.53518548909936,1
    49.07256321908844,51.88321182073966,0
    80.27957401466998,92.11606081344084,1
    66.74671856944039,60.99139402740988,1
    32.72283304060323,43.30717306430063,0
    64.0393204150601,78.03168802018232,1
    72.34649422579923,96.22759296761404,1
    60.45788573918959,73.09499809758037,1
    58.84095621726802,75.85844831279042,1
    99.82785779692128,72.36925193383885,1
    47.26426910848174,88.47586499559782,1
    50.45815980285988,75.80985952982456,1
    60.45555629271532,42.50840943572217,0
    82.22666157785568,42.71987853716458,0
    88.9138964166533,69.80378889835472,1
    94.83450672430196,45.69430680250754,1
    67.31925746917527,66.58935317747915,1
    57.23870631569862,59.51428198012956,1
    80.36675600171273,90.96014789746954,1
    68.46852178591112,85.59430710452014,1
    42.0754545384731,78.84478600148043,0
    75.47770200533905,90.42453899753964,1
    78.63542434898018,96.64742716885644,1
    52.34800398794107,60.76950525602592,0
    94.09433112516793,77.15910509073893,1
    90.44855097096364,87.50879176484702,1
    55.48216114069585,35.57070347228866,0
    74.49269241843041,84.84513684930135,1
    89.84580670720979,45.35828361091658,1
    83.48916274498238,48.38028579728175,1
    42.2617008099817,87.10385094025457,1
    99.31500880510394,68.77540947206617,1
    55.34001756003703,64.9319380069486,1
    74.77589300092767,89.52981289513276,1
    

    ex2data2.txt

    0.051267,0.69956,1
    -0.092742,0.68494,1
    -0.21371,0.69225,1
    -0.375,0.50219,1
    -0.51325,0.46564,1
    -0.52477,0.2098,1
    -0.39804,0.034357,1
    -0.30588,-0.19225,1
    0.016705,-0.40424,1
    0.13191,-0.51389,1
    0.38537,-0.56506,1
    0.52938,-0.5212,1
    0.63882,-0.24342,1
    0.73675,-0.18494,1
    0.54666,0.48757,1
    0.322,0.5826,1
    0.16647,0.53874,1
    -0.046659,0.81652,1
    -0.17339,0.69956,1
    -0.47869,0.63377,1
    -0.60541,0.59722,1
    -0.62846,0.33406,1
    -0.59389,0.005117,1
    -0.42108,-0.27266,1
    -0.11578,-0.39693,1
    0.20104,-0.60161,1
    0.46601,-0.53582,1
    0.67339,-0.53582,1
    -0.13882,0.54605,1
    -0.29435,0.77997,1
    -0.26555,0.96272,1
    -0.16187,0.8019,1
    -0.17339,0.64839,1
    -0.28283,0.47295,1
    -0.36348,0.31213,1
    -0.30012,0.027047,1
    -0.23675,-0.21418,1
    -0.06394,-0.18494,1
    0.062788,-0.16301,1
    0.22984,-0.41155,1
    0.2932,-0.2288,1
    0.48329,-0.18494,1
    0.64459,-0.14108,1
    0.46025,0.012427,1
    0.6273,0.15863,1
    0.57546,0.26827,1
    0.72523,0.44371,1
    0.22408,0.52412,1
    0.44297,0.67032,1
    0.322,0.69225,1
    0.13767,0.57529,1
    -0.0063364,0.39985,1
    -0.092742,0.55336,1
    -0.20795,0.35599,1
    -0.20795,0.17325,1
    -0.43836,0.21711,1
    -0.21947,-0.016813,1
    -0.13882,-0.27266,1
    0.18376,0.93348,0
    0.22408,0.77997,0
    0.29896,0.61915,0
    0.50634,0.75804,0
    0.61578,0.7288,0
    0.60426,0.59722,0
    0.76555,0.50219,0
    0.92684,0.3633,0
    0.82316,0.27558,0
    0.96141,0.085526,0
    0.93836,0.012427,0
    0.86348,-0.082602,0
    0.89804,-0.20687,0
    0.85196,-0.36769,0
    0.82892,-0.5212,0
    0.79435,-0.55775,0
    0.59274,-0.7405,0
    0.51786,-0.5943,0
    0.46601,-0.41886,0
    0.35081,-0.57968,0
    0.28744,-0.76974,0
    0.085829,-0.75512,0
    0.14919,-0.57968,0
    -0.13306,-0.4481,0
    -0.40956,-0.41155,0
    -0.39228,-0.25804,0
    -0.74366,-0.25804,0
    -0.69758,0.041667,0
    -0.75518,0.2902,0
    -0.69758,0.68494,0
    -0.4038,0.70687,0
    -0.38076,0.91886,0
    -0.50749,0.90424,0
    -0.54781,0.70687,0
    0.10311,0.77997,0
    0.057028,0.91886,0
    -0.10426,0.99196,0
    -0.081221,1.1089,0
    0.28744,1.087,0
    0.39689,0.82383,0
    0.63882,0.88962,0
    0.82316,0.66301,0
    0.67339,0.64108,0
    1.0709,0.10015,0
    -0.046659,-0.57968,0
    -0.23675,-0.63816,0
    -0.15035,-0.36769,0
    -0.49021,-0.3019,0
    -0.46717,-0.13377,0
    -0.28859,-0.060673,0
    -0.61118,-0.067982,0
    -0.66302,-0.21418,0
    -0.59965,-0.41886,0
    -0.72638,-0.082602,0
    -0.83007,0.31213,0
    -0.72062,0.53874,0
    -0.59389,0.49488,0
    -0.48445,0.99927,0
    -0.0063364,0.99927,0
    0.63265,-0.030612,0
    

    Author: @臭咸鱼

    When reposting, please credit the source: https://www.cnblogs.com/chouxianyu/

    Discussion and feedback are welcome!

