• 3.2_k-近邻算法案例分析


     

    k-近邻算法案例分析

    本案例使用最著名的”鸢尾“数据集,该数据集曾经被Fisher用在经典论文中,目前作为教科书般的数据样本预存在Scikit-learn的工具包中。

    读入Iris数据集细节资料

    from sklearn.datasets import load_iris
    # 使用加载器读取数据并且存入变量iris
    iris = load_iris()
    
    # 查验数据规模
    iris.data.shape
    
    # 查看数据说明(这是一个好习惯)
    print iris.DESCR

    通过上述代码对数据的查验以及数据本身的描述,我们了解到Iris数据集共有150朵鸢尾数据样本,并且均匀分布在3个不同的亚种;每个数据样本有总共4个不同的关于花瓣、花萼的形状特征所描述。由于没有制定的测试集合,因此按照惯例,我们需要对数据进行随即分割,25%的样本用于测试,其余75%的样本用于模型的训练。

    由于不清楚数据集的排列是否随机,可能会有按照类别去进行依次排列,这样训练样本的不均衡的,所以我们需要分割数据,已经默认有随机采样的功能。

    对Iris数据集进行分割

    from sklearn.cross_validation import train_test_split
    X_train,X_test,y_train,y_test = train_test_split(iris.data,iris.target,test_size=0.25,random_state=42)

    对特征数据进行标准化

    from sklearn.preprocessing import StandardScaler
    
    ss = StandardScaler()
    X_train = ss.fit_transform(X_train)
    X_test = ss.fit_transform(X_test)

    K近邻算法是非常直观的机器学习模型,我们可以发现K近邻算法没有参数训练过程,也就是说,我们没有通过任何学习算法分析训练数据,而只是根据测试样本训练数据的分布直接作出分类决策。因此,K近邻属于无参数模型中非常简单一种。

    from sklearn.datasets import load_iris
    from sklearn.cross_validation import train_test_split
    from sklearn.preprocessing import StandardScaler
    from sklearn.neighbors import KNeighborsClassifier
    from sklearn.metrics import classification_report
    from sklearn.model_selection import GridSearchCV
    
    def knniris():
        """
        鸢尾花分类
        :return: None
        """
    
        # 数据集获取和分割
        lr = load_iris()
    
        x_train, x_test, y_train, y_test = train_test_split(lr.data, lr.target, test_size=0.25)
    
        # 进行标准化
    
        std = StandardScaler()
    
        x_train = std.fit_transform(x_train)
        x_test = std.transform(x_test)
    
        # estimator流程
        knn = KNeighborsClassifier()
    
        # # 得出模型
        # knn.fit(x_train,y_train)
        #
        # # 进行预测或者得出精度
        # y_predict = knn.predict(x_test)
        #
        # # score = knn.score(x_test,y_test)
    
        # 通过网格搜索,n_neighbors为参数列表
        param = {"n_neighbors": [3, 5, 7]}
    
        gs = GridSearchCV(knn, param_grid=param, cv=10)
    
        # 建立模型
        gs.fit(x_train,y_train)
    
        # print(gs)
    
        # 预测数据
    
        print(gs.score(x_test,y_test))
    
        # 分类模型的精确率和召回率
    
        # print("每个类别的精确率与召回率:",classification_report(y_test, y_predict,target_names=lr.target_names))
    
        return None
    
    if __name__ == "__main__":
        knniris()

     

  • 相关阅读:
    js打印指定元素内容
    c# RedisHelper
    T4生成整理
    T4随记
    c# 文本超长截断
    mysql自动安装教程说明
    完全卸载mysql免安装版
    解决WebClient或HttpWebRequest首次连接缓慢问题
    c# 停靠窗体
    c#透明panel
  • 原文地址:https://www.cnblogs.com/alexzhang92/p/10070226.html
Copyright © 2020-2023  润新知