• [Python]基于K-Nearest Neighbors[K-NN]算法的鸢尾花分类问题解决方案


      看了原理,总觉得需要用具体问题实现一下机器学习算法的模型,才算学习深刻。而写此博文的目的是,网上关于K-NN解决此问题的博文很多,但大都是调用Python高级库实现,尤其不利于初级学习者本人对模型的理解和工程实践能力的提升也不利于Python初学者实现该模型。

      本博文的特点:

        一 全面性地总结K-NN模型的特征、用途

        二  基于Python的内置模块,不调用任何第三方库实现

      博文主要分为四部分:

        基本模型(便于理清概念、回顾模型)

        对待解决问题的重述

        模型(算法)和评价(一来,以便了解模型特点,为以后举一反三地应用作铺垫;二来,有利于以后快速复习)、

        编程实现(Code)。

      特别声明:

        1.劳动成果开源,未经同意博主(千千寰宇:http://cnblogs.com/johnnyzen),不得以任何形式转载、复制。

        2.如有纰漏或者其他看法,欢迎共同探讨~

    零 基本模型

      (本部分内容,均来源于引用[1],其原理讲解十分通俗易懂)

      ①K-近邻算法,即K-Nearest Neighbor algorithm,简称K-NN算法。单从名字来猜想,可以简单粗暴的认为是:K个最近的邻居,当K=1时,算法便成了最近邻算法,即寻找最近的那个邻居。

      ②所谓K-NN算法,即是给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最邻近的K个实例(也就是K个邻居), 这K个实例的多数属于某个类,就把该输入实例分类到这个类中。

      ③实例

        猜猜看:有一个未知形状(绿色圆点),如何判断其是什么形状?

        问题:给这个绿色的圆分类?

        对噪声数据过于敏感。为了解决这个问题,我们可以把位置样本周边的多个最近样本计算在内,扩大参与决策的样本量,以避免个别数据直接决定决策结果。

        有两类不同的样本数据,分别用蓝色的小正方形和红色的小三角形表示,而图正中间的那个绿色的圆所标示的数据则是待分类的数据。

        如果K=3,判定绿色的这个待分类点属于红色的三角形一类。

        如果K=5,判定绿色的这个待分类点属于蓝色的正方形一类。

    一 问题

      题目:ML之k-NN:k-NN实现对150朵共三种花的实例的萼片长度、宽,花瓣长、宽数据统计,根据一朵新花的四个特征来预测其种类
      数据源:https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
      数据源说明:https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.names

    二 解决过程及模型评价

      一 审题/准备数据集
        1.1 明确问题基本模型,及涉及要素(特征值、有无标记、可考虑的基本算法模型):K-NN、分类、监督学习、有标记、鸢尾花分类
        1.2 准备数据集及其处理方法
          数据源:https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
          数据源说明:https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.names
      二 分析并设计解决思路(算法步骤)
        2.0 零 样本及样本空间、测试空间的程序表示
            设计样本的class类:Iris(鸢尾花类):
              便于将样本空间以数组化表示 + 便于后期训练对各样本的取值、标记、比对等操作
        2.1 一 读取数据集,装载样本集
            提取文件数据至计算机进程(程序)中,并恰当地表示样本、样本集、测试集
          文件I/O操作:
            open(filepath+filename,mode)、
            file.readline()、
            file.readlines()、
            file.next()、
          字符串处理:
            str.split():切割
            float(str)、int(str):数值转换
          注意:对数据集文件中无关数据的处理(跳行、文件转进制、跳列等过滤)
        2.2 二 训练测试样本
          算法描述:[以鸢尾花分类为例]
            遍历测试样本
              计算测试样本与已标记的样本的欧式距离
            对各欧式距离升序排序
            选择前K项的已样本作为一子集( 即 选择最近的K项邻居作为参照标准)
              遍历统计,已标记子集的花朵种类何种花朵数目种类最多
            设置当前测试样本的预测花朵种类为该种
            结束。
          注释:花的种类分别为:Iris-setosa、Iris-versicolor、Iris-virginica;共计3种。
        2.3 三 计算预测准确率
          rate = Count(测试验证集中,已标记值 == 预测值 的样本) / 测试验证集总数
        2.4 四 模型检验(未做)
      三 编程实现
        (见第三部分示范代码)
      四 总结模型[K-NN]
        针对本问题
          0 样本数据集的样本顺序排列,选拔初始的参照样本空间时,需要先将数据集乱序排序,再做选拔;否则,分类准确率往往趋于0
            原因:对未纳入训练样本对应种类的待测样本,几乎无真实参照集合,所以将导致预测结果差得离谱(准确率趋于0)。
        涉及因素
          0 特征向量Features
            影响算法的处理效率:特征向量数越多,计算量越大,处理时间越长。
            影响算法的结果正确性:K值越多,可参照指标越多,结果越趋于正确。
          1 K:K个最近邻居元素的控制
            控制待预测样本的参照边界
            影响算法的结果正确性:
            K=1时,选拔最近的元素的标记值作为预测值
          2 Weight权值:对参照邻居的选拔标准
            标准:欧式距离、切比雪夫距离、马氏距离(这个,目前不清楚)、最近者赋予更大权值(增加其重要性)
            控制参照样本对待测样本的重要性
            有利于选拔最贴近真实预测结果、最符合的样本
          3 已预测并标记的样本是否加入参照样本集,影响后续待测样本预测?
          4 样本数量与种类
            标准:种类全面、数量按比例丰富;在上前提下,样本越多越好
              否则,尤其是种类覆盖不全面时,可能严重影响对偏未标记种类的样本预测效果
            样本种类↑,数量↑,预测准确率↑,预测处理效率↓
        总结性评价
          1 监督学习、分类问题
          2 基于实例数据的非参数学习算法
            Input:已标记数据集中的K个最近的训练样本组成
            Output:判别待测样本之类型
          3 可用于非线性分类
          4 样本种类↑,数量↑,特征向量↑:预测准确率↑,预测处理效率↓
          5 性能与完美的准确性不能绝对两全
            总体处理的复杂度:O(n)
          6 对异常值不敏感
          7 总体来说,准确度高,无人为因素。
          8 K值大小对结果的准确性要视具体情况而定,人为调参优化
          9 样本数量与种类:
            标准:种类全面、数量按比例丰富;在上前提下,样本越多越好
            否则,尤其是种类覆盖不全面时,可能严重影响对偏未标记种类的样本预测效果
          10   已预测的样本是否加入参照样本空间,将影响后续待测样本对预测结果起多大作用是一个可考虑的问题
          11 其他:
            需要提前知道样本空间之所有种类
            预测样本时,计算量大,[不利于实时预测],偏胖服务器模式
          12 模型常见实践:
            手写数字识别系统
            鸢尾花分类
            爱情片与动作片分类
            约会网站匹配
            对新贷款用户的还款情况预测
            ========================
            文本分类
    三 编程实现(For Python)
      工程文件分为三部分:
        __init__.py【main()启动函数、核心算法】
        Iris.py【设计数据结构(类)、模块(职责分离)】
        file_handle.py【数据提取、文件处理】

        由于前面第二部分已经详细叙述,且代码中注释已经十分详细,便不在对代码进行解释,阅读注释便容易懂。

        __init__.py

    import random;
    import math;
    import Iris; # 自定义
    import file_handle; # 自定义
    
    #if __name__ == '__main__': #__name__ == '__main__'是Python的main函数入口
    def main(print_test=False,print_samples=False):
        follows = []; # 样本集空间(前sampleAmount项)+测试集空间(后 sampleAll-sampleAmount项)
        data = "";    # 样本数据(前sampleAmount项)+测试集空间(后 sampleAll-sampleAmount项)
        sampleAll = 150;
        sampleAmount = 100; # 标记样本集数目(剩余的便作为测试集)
        k = 5;
        test_print = print_test; 
        
        ########## 一 读取数据集,装载样本集
        ##### 1.1 加载数据集数据
        data = file_handle.read("./dataset/data.txt",1,'r');#1:忽略第一行
        # print(data);
        list = data.split('
    ');
        i = 0;
        for line in list: # 如:line = "5.1,3.5,1.4,0.2,Iris-setosa"
            item = line.split(',');# 如:item = [5.1,3.5,1.4,0.2,Iris-setosa]
            label_species = item.pop();#移除最后一项:标记种类
            #print("[test] item:", item,"	label_species:", label_species); # test
            follows.append(Iris.Iris(item,label_species));
            #print("[ ",i," ] ",follows[i].toString());
            i += 1;
            pass;
        random.shuffle(follows); # 【千万注意!!!】由于原数据集是有序的,如果不做乱序处理,预测结果会及其不理想(准确率,趋近于0),当然,这也是这一模型的缺陷之一
        ##### 1.2 选择前100项 作为已标记样本集
        #i = 0;
        #for i in range(sampleAmount):
        #    follows[i].setPredictSpecies(follows[i].label_species);
        #    pass;
        
        ########## 二 训练测试样本
        ##### 2.1 对101 - 150 项的测试集进行训练/预测
        ## 算法描述:
        ##     遍历测试样本
        ##        计算测试样本与已标记的样本的欧式距离
        ##     对各欧式距离升序排序
        ##     选择前K项的已样本作为一子集( 即 选择最近的K项邻居作为参照标准)
        ##        遍历统计,已标记子集的花朵种类何种花朵数目种类最多
        ##        设置当前测试样本的预测花朵种类为该种
        ##     结束。  
        ## 注释:花的种类分别为:Iris-setosa、Iris-versicolor、Iris-virginica;共计3种。
        offset = 0; # 测试空间偏移量:目的是为了将通过偏移量,增大原已标记样本空间的样本数量 即 使已预测的测试样本加入参照样本空间。
        for x in range(sampleAmount,sampleAll):# x:测试样本下标
            weights = [];# 对各欧式距离(权值)的升序排序列表
            for y in range(0,sampleAmount+offset):
                result = (math.sqrt( + 
                                   math.pow(follows[y].features[0] - follows[x].features[0],2) + 
                                   math.pow(follows[y].features[1] - follows[x].features[1],2) + 
                                   math.pow(follows[y].features[2] - follows[x].features[2],2) + 
                                   math.pow(follows[y].features[3] - follows[x].features[3],2)), y);# 存储x,方便排序后定位花朵
                #print("[test] weights[x]:", result);
                weights.append(result);
                pass;
            weights.sort(key = lambda item:item[0]); # 以各元组内第一首项[欧氏距离]为键,默认升序排序
            if test_print:
                for m in range(len(weights)): # 输出预测权重
                    print("[test] weights[",m,"]:",weights[m],"	weights[",m,"][1] > ",weights[m][1],":",follows[weights[m][1]].toString());
            kinds_count = {"Iris-setosa":0,"Iris-versicolor":0,"Iris-virginica":0}; # 对已标记样本空间中各种花的数目统计作初始化
            for z in range(0,k): # 选择前K项的已样本作为一子集( 即 选择最近的K项邻居作为参照标准)
                if test_print:         
                    print("[test] 排名前",z+1,"项 follows[",z,"]:",follows[weights[z][1]].toString());
                label_species = follows[weights[z][1]].label_species;
                if(label_species == 'Iris-setosa'):
                    kinds_count["Iris-setosa"] += 1;
                elif label_species == 'Iris-versicolor':
                    kinds_count["Iris-versicolor"] += 1;
                elif label_species == 'Iris-virginica':
                    kinds_count['Iris-virginica'] += 1;
                else:
                    print("[ERROR:Unknown Species] follows[",weight[z][1],"]:",follows[weight[z][1]]);
                pass;
            result =  max(kinds_count.items(), key = lambda item:item[1]); # 取统计花类数字典中最大值对应的序列
            follows[x].predict_species = result[0]; # 标记预测种类
            if test_print:
                print("[test] 预测结果",result, " [follows[",x,"].predict_species]:", follows[x].predict_species); # test
            offset += 1;
            #for test in range(len(weights)): # 测试-输出距离权值结果
            #    print("[",test,"] weights:",weights[test][0],"	",follows[weights[test][1]].toString());
            #    pass;
            pass;
            
        ########## 三 计算预测准确率
        rate = 0.0;
        i = 0;
        for i in range(sampleAmount,len(follows)):
            if(follows[i].label_species == follows[i].predict_species):
                rate += 1;
            else:
                print("[预测错误样本] follow[",i,"]:",follows[i].toString());
            pass;
        pass;
        rate = rate / (sampleAll - sampleAmount);
        print("预测准确率:",rate);
         
        if print_samples:
            for i in range(0,len(follows)):
                print(follows[i].toString());
                pass;
        pass;
    
    main(False,True);

        Iris.py

    'Iris module [class] '
    
    __author__ = 'Johnny Zen'
    
    class Iris:
        """
        Iris花(类)
        
        [Demo]
        iris = Iris([5.1,3.5,1.4,0.2],"Iris-setosa");
        print(iris.toString());
        iris.setPredictSpecies('Iris-setosa');
        print(iris.toString());
        print(iris.label_species);
        =======================
        [features][5.1, 3.5, 1.4, 0.2]    [label-species]Iris-setosa    [predict-species]None
        [features][5.1, 3.5, 1.4, 0.2]    [label-species]Iris-setosa    [predict-species]Iris-setosa
        Iris-setosa
        """
        features = [];
        label_species = None; # 标记种类
        predict_species = None; # 预测种类
        def __init__(self,features,label_species=None):
            if type(features).__name__ == 'list':
                self.features = features;
            else:
                self.features = list(features); # 此list方法对list对象执行将产生错误
                pass;
            for x in range(len(self.features)): # 列表内元素字符串转实数
                self.features[x] = float(self.features[x]);
            self.label_species = label_species;
            pass;
        def setPredictSpecies(self,predict_species=None):#设置预测种类
            self.predict_species = predict_species;
            pass;
        def toString(self):#与一般函数定义不同,类方法必须包含参数 self[第一个参数]
            return "[features]" + str(self.features) + "	[label]" + str(self.label_species + "	[predict]" + str(self.predict_species));
            pass;
        pass;

        file_handle.py

    "file_handle module [function]:read(filepath,ignore=0,mode='r')"
    
    def read(filepath,ignore=0,mode='r'):
        try:
            file = open(filepath,mode);
            ## file_content = file.read();
            file_content = '';
            i = 0;
            for i in range(0,ignore):
                file.readline();
                ##print(i);
                ##print(file.readline()); 
            for line in file.readlines():
                file_content += line;
        finally:
            if file:
                file.close();
            #print(file_content);
            return file_content;
        pass;

    四 参考文献

      [1] K-NN和K-Means算法

  • 相关阅读:
    Java连载63-异常处理try...catch...、方法getMessageyu printStackTrace
    Python连载58-http协议简介
    Java连载62-使用throws关键字处理异常
    HTML连载57-相对定位和绝对定位
    Java连载61-异常的机制与分类
    Python连载57- 邮件头和主题、解析邮件
    Java连载60-类之间的六种关系
    [Java] 数据库编程JDBC
    [bug] MySQL-Front连接MySQL 8.0失败
    [bug]mysql: The server time zone value 'Öйú±ê׼ʱ¼ä' is unrecognized or represents more than one time zone
  • 原文地址:https://www.cnblogs.com/johnnyzen/p/9625535.html
Copyright © 2020-2023  润新知