• 《机器学习实战(基于scikit-learn和TensorFlow)》第三章内容的学习心得


    本章主要讲关于分类的一些机器学习知识点。我会按照以下关键点来总结自己的学习心得:(本文源码在文末,请自行获取)

    • 什么是MNIST数据集
    • 二分类
    • 二分类的性能评估与权衡
    • 从二元分类到多类别分类
    • 错误分析
    • 多标签分类、多输出分类

     

    什么是MNIST数据集

    MNIST数据集是一组由美国高中生和人口调查局员工手写的70,000个数字图片数据集。官方链接为:http://yann.lecun.com/exdb/mnist/

    这组数据集X标签是28*28大小的像素强度数值,y标签是一个该图像对应的一个真实数字。

    我们通过sklearn提供的函数可以对该数据集进行下载:

    image

    这个fetch_mldata会将名字为MNIST original的数据集通过官方库中的数据集下载下来,返回的是一个dict对象。(dict对象介绍

    一般童鞋应该是下不下来,文章底部的链接中有对应数据供大家直接下载。

    大家下载完毕后,执行下方代码:

    from sklearn.datasets.base import get_data_home
    print (get_data_home())

    查找到sklearn对于你的机器上的数据集缓存地址,将下载的文件中DataSets中的mnist-original.mat文件直接复制到显示位置即可。

     

    现在假定,我们数据集已经处理完毕。

    每次我们针对数据集的观察是必不可少的操作,因此先看看返回的dict对象中都有什么对我们是有好处的:

    image

    通过加载sklearn的数据集,通常包括:DESCR(描述数据集)、data(包含一个数组,每个实例为一行,每个特征为一列,即我们的x)、target(包含一个带有标记的数组,即我们的y)

    加载其中某些内容:

    image

    看出X是一个70000*784的矩阵,也能得出我们可以训练的实例有7万个,每个实例都可以表示成为一个28*28的矩阵(784开根号为28),对应是一个图像。y是一个标签,相应也有7万个。

    我们将其中某一个数字画出来,可以更加直观的表示:

    image

    matplotlib上一章心得说了一些使用介绍,这里不详细讲了。这里只说matplotlib中的imshow()函数。

    imshow()函数功能就是针对提供的像素点,生成一张2维图片。该函数的参数非常多,具体请看链接

    从显示的情况看,这个图像看上去更像是数字5,验证一下:

    image

    猜的没错,是5。

    在深入研究该数据集之前,我们首先应该为以后的分类算法划分测试与训练集。

    这里需要说明的是,MNIST数据集已经帮助我们分好了一个数据集划分,前60,000个数据是训练集,后10,000为测试集,所以我们直接划分即可:

    image

    但是需要注意的是,这里的每个集合里面的数据划分非常有规律,按照0-9的顺序排列,这对于我们的训练是存在问题的,我们应该随机将集合重新洗牌:numpy中的random类中的permutation函数可以达到重新洗牌的目的:

    image

    具体使用请看链接

     

    二分类

    如果我们只需要检测其中的某一个数字,那么我们可以将“识别某个数字与否”这个问题看成是一个二分类问题。假设我们这里需要识别数字5,那么最后识别出的结果就只有两个:5或非5。我们相当于需要构建一个数字5的检测器。

    首先,对该任务创建目标向量(其实就是我们所说的y,这里所做的操作是一个方便以后的处理,不做这样的处理也是可以的),这个是用来标识识别结果是否正确,标签为5的表示为true,其他为False:

    image

    构建好目标向量之后,选择一个分类算法构建我们的分类器,这里选择随机梯度下降分类器(SGDClassifier):

    image

    从之前所有引入的估算器,不管是决策树、随机森林、还是线性回归、逻辑回归还是这里的随即梯度下降,sklearn针对这些模型的初始化都非常的类似,都是首先导入,然后初始化该估算器类的构造函数,这里判断我们是否需要针对算法的某些参数进行修改,如果采用默认则直接无参初始化,如果需要,则传入需要修改的参数值,之后采用fit()函数,传入训练集的x与y,进行训练,最后得到一个训练有素的分类器。

    然后我们通过predict()函数进行预测:

    image

    由于some_digit对应的数字之前我们发现是5,因此预测结果为Ture。

     

    二分类的性能评估与权衡

    首先,采用交叉验证测量精度。该方法的含义是使用将训练集分成K份,每次用K-1份进行训练,剩下的1份做测试,每次算出一个精确度并输出,我们首先自己编写:

    image

    还可以用cross_val_score()函数来实现:

    image

    image

    我们可以发现,每次的精确度都在90%以上,最后一次竟然能到96%?我的模型这么好吗?!为了让我们放下心来,我们可以做如下操作,设置一个非常笨的分类器,直接预测所有数字都不是5,我们来看看精度:

    image

    image

    image

    看见了吗,我们认为的最笨的分类器,竟然都能达到90%的概率!有问题,绝对有问题!

    我想,你已经明白了,这是因为我们的数据中,每个数字大概占总数的10%左右,因此猜一张图不是5的概率都可以达到90%!因此,我们采用精确度这个指标是存在问题的。特别是,当我们处理偏斜数据集(某类型数据非常多,数据分布不平衡的数据集)时,精确度这个性能指标是绝对不可以的。

    我们应该采用混淆矩阵的方法!

    混淆矩阵,总体思路就是构成一个矩阵,这个矩阵记录统计出A类别的实例被错误的分成B类别实例的次数。具体看下图我进行解释:

    image(图片源自:https://blog.csdn.net/Orange_Spotty_Cat/article/details/80520839

    我们来看这张图,我们想看本来是猫但被误认为狗的情况有多少次,从矩阵中,我们就找真实值为猫,预测值为狗的对应行列即可,找到可以看出是3次,这就是混淆矩阵。

    我们再升华一下这个矩阵,从理论的高度解释一下:

    image(图片源自:https://blog.csdn.net/Orange_Spotty_Cat/article/details/80520839

    这里的右下角的每个白格子,都代表一种预测情况。其实对应多分类的问题,也可以转换为二分类问题,只需要将待分类的类别归为一类,剩下归为一类即可。

    好了,这里的TP、FP、FN、TN要解释一下编码的构成:预测是否错误+预测的类别:

    • TP(True Positive):预测正确,预测的类别为正例(即需要预测的那个类别)
    • FP(False Postive):预测错误,预测类别为正例
    • FN(False Negative):预测错误,预测类别为负例
    • TN(True Negative):预测正确,预测类别为负例

    先说这么多,我们先把概念记住,先上一波代码,构成我们需要的混淆矩阵吧。

    要计算混淆矩阵,需要一组预测才能比较,但测试集永远记住都是放在项目启动后再进行,因此我们还是采用交叉验证的方式进行,这次用cross_val_predict()函数:

    image

    image

    这里注明一下,再sklearn中生成的混淆矩阵中,行表示实际类别,列表示预测类别,与之前给出的那个矩阵行列正好相反,不过我们只要能分出TP、TN、FP、FN就好。

    从这个混淆矩阵中,我们可以发现53517张被正确预测为“非5”的类别,4236张被正确预测为“5”的类别,1185张被错误的分为“非5”的类别,1062张被错误的分为“5”类别。

    最优的混淆矩阵应该是只存在T开头的属性,其他情况的数字应该是0。

    混淆矩阵可以给我们提供大量的信息,但如果指标如果能更简介些,我们可以更加直观的了解一些最需要的信息,接下来介绍精度、召回率的概念。

    首先,给出精度的概念:

    image

    该公式计算的是正确被分为正类占预测器全部预测为正类的数量的比率。为什么叫精度呢?我们可以这么想,预测正类有这么多,其中预测的对的数量又有这么多,这不就是计算了一个精度嘛?

    接下来,我们谈谈召回率:

    image

    该公式计算的是正确被分为正类的数量占总正类数量(因为FN对应这些情况的真实值应该是正类呀)的比率。为什么叫召回率呢?我们可以举个例子来思考这个问题。比如车吧,车厂在出售车之前,需要对车进行一个检测,我们这里定义的正类就是“存在问题的车”。当然存在有些车本身存在问题,但没检测出来就售出了,然后呢车厂发起召回问题车,这时候我们就需要召回率来帮忙了!当然,我们可想而知,召回率太高,估计老板得气死。。-。-言归正传,召回率也叫查全率,就是计算的是总共的正类中,我一共能正确分出多少的正类的比率。

    上代码,计算吧:

    image

    发现没,这个分类器没有之前那么亮眼了,我就知道!!哼!!Who me?

    我们还有一个指标可以结合上述两个指标,那就是,,,,,F1分数!

    上公式:

    image

    怎么理解F1分数呢?F1分数是精度和召回率的谐波平均值。正常的平均值平等对待所有值,而谐波平均值会给予较低的值更高的权重。因此当召回率和精度都很高时,分类器才能得到较高的F1分数,因此F1分数越高,能说明,我们的系统更加稳健。

    但是,我们不应该把追求F1分数作为我们的最终目标,我们要看实际的要求权衡召回率和精度,举个例子让大家明白:

    假设,需要训练一个分类器来检测儿童可以放心观看的视频,我们应该本着宁可错杀100不可放过一个的目的,可能会拦截好多好视频(低召回率),但确保保留下的视频都是安全的(高精度)。相反,如果需要训练一个分类器通过图像监控检测小偷,这个分类器应该本着不管这个人是否是不是小偷,当他做类似小偷的行为时,就应该发出警报。所以我们当然希望的是尽可能抓住更多的小偷咯,因此我们要求召回率要达到99%以上(可能会误报很多次,但几乎窃贼都在劫难逃!),但这样的话,我们的精度是会下降的。

    接下来解释一下为何精度和召回率不能兼得:

    让我们想一个问题,在极端的情况下,假设你在查找一个问题的答案,如果要求精度非常高,那么返回的结果就会很少,但都是你要的,如果要求召回率(查全率)很高,那么返回的结果很多,但其中有很多的结果是你不需要的。因此,需要高精度就会导致低召回率,反之亦然。

    返回到我们的数字分类的问题上,我们看一下SGDClassifier如何进行分类决策的。

    这个分类器,对于每个实例,它会基于决策函数计算一个分值,如果该值大于阈值,则判定为正类,否则为负类。放个图,解释一下:

    image

    假设,我们阈值在中间箭头位置,在阈值的右侧可以找到4个真正类(四个5),一个假正类(一个6)。因此,在该阈值下,精度为80%(4/5),召回率为67%(4/6)。当我们提高阈值,假正类的6就会变成负类,那么精度会提升到100%(3/3),但一个真正类变成了一个假负类,召回率变为50%(3/6)。

    sklearn不允许直接设置阈值,但是可以访问它用于预测的决策分数,我们可以基于分数,使用阈值预测,上代码:

    image

    我们如何选择阈值呢?我们先获取训练集中所有实例的分数:

    image

    image

    再使用precision_recall_curve()函数计算所有可能的阈值的精度和召回率,最后绘制精度和召回率相对于阈值的函数图像:

    image

    还有一种与二元分类器一起使用的工具:受试者工作特征曲线(ROC),该曲线绘制的是召回率和假正类率(FPR:被错误分为正类的负类实例比率,等于1-真负类率[被正确分类为弗雷德负类实例比率,也称为特异度])。

    使用roc_curve()函数计算多种阈值的TPR和FPR,然后绘制FPR对于TPR的曲线:

    image

    同样,召回率越高,分类器产生的假正类越多,虚线表示纯随机分类器的ROC曲线,一个优秀的分类器应该离这条线越远越好。有一个比较分类器的方法是测量曲线下面积(AUC):

    image

    那么,问题来了,如何选择指标呢?这里直接引用书上的原话:当正类非常少见或者你更关注假正类而不是假负类时,你应该选择PR曲线,反之ROC。

    训练一个随机森林的分类器,并比较SGD分类器:

    image

    image

    计算它的AUC得分:

    image

    它的精度和召回率如下:

    image

    image

    从二元分类到多分类

    我们的数字分类问题,可以分为10个二分类问题,在检测图片分类时,可以获取每个分类器的决策分数,哪个分高就决定时哪个数字。这是OvA策略,一对多。还有一种情况可以为每一对数字训练分类器,这称为OvO,一对一策略。

    有些算法在数据规模扩大时,表现糟糕,对于这类算法,一对一是优先选择,如果不是的话就一对多。

    sklearn会检测到你尝试使用二元分类算法进行多类别的分类任务,它会自动进行OvA:

    image

    image

    每个类别得出的分数以及分类器分出的类别我们都可以知道:

    image

    我们也可以进行一对一的分类策略,预测器的个数也能显示出来:

    image

    image

    至于评估分类器,提升准确率,这里不再赘述。只放代码,自己看就好:

    image

    image

    image

    image

    image

    image

    错误分析

     这里,假设我们已经找到了一个有潜力的模型,我们希望对该模型进行改进,我们可以分析其错误类型,帮助我们。

    首先可以看看混淆矩阵:

     我擦,数字有点多,看的不清楚,怎么才能更形象的表示呢?我们可以将该矩阵可视化:

    越白是表示数字越多,越黑表示数字越少。由于大多数白色都在对角线上,所以我们可以认为大部分的数字与图片可以正确分类。从局部上,我们可以发现,白色的部分也存在差异,数字2的白色就很好,而数字5、数字8看起来比其他的要差一些,可能的原因有:数据集中图片较少,或者数字5在执行效果上不如其他数字。

    我们把问题集中在错误上,将混淆矩阵中的每个值除以相应类别中的图片数,得出错误率:

    然后用0填充对角线,只保留错误率,重新绘制:

    目前,就可以看出分类器产生的错误种类了。记住行为实际类别,列为预测类别。我们可以发现,第6列第4行和第4列第5行两个格子很亮,说明分类器针对数字3、数字5容易混淆,其他的白色格子,我们可以进行同样的分析。

    下面,我们来具体看一下数字3和5分错的图片混淆矩阵。首先定义一个数字图像显示的函数:

    上述定义的函数大家直接拷贝即可,如果想深入研究其中的运行原理,请参考具体的文档进行查看,不过也不会难,只是需要一定的逻辑就好。

    然后对图片进行一个显示:

    上图中,左侧两个5*5的矩阵显示的是被分类为数字3的图,右侧两个5*5的矩阵显示了被分类为5的图。分类器弄错的数字就是左下方和右上方的图片。通过对比,我们可以发现有些数字用人脑来分辨也真的很容易分错,因此知道分类器在具体的分类问题上的差异后,我们就应该根据具体的问题,通过具体的手段进行修正。比如多采集数字3和数字5的数据,或者针对数字3和5的形状结构,开发新的特征来改进分类器,对图像预处理,或者采用更高级的算法等等来解决问题。

    多标签分类、多输出分类

    我们之前所作的分类器都会将一个实例分在一个类别中,而在某些情况下,需要分类器为每一个实例产出多个类别。比如分类照片中的人像,一张照片可能存在很多个人,分类器就可以输出照片中的人对应的姓名。这种输出多个二元标签的分类系统成为多标签分类系统。

    接下来,我们针对这个问题,创建一个多标签数组,这个数组包含数字图片的目标标签:第一个数字表示是否大数(7,8,9),另一个表示该数是否是奇数。我们这里需要注意的是,不是所有算法都能支持多标签的分类,我们这里选择KNeighborsClassifier分类器进行分类:

    我们会发现,最后输出的结果就是一个多标签的结果。

    最后,我们讨论一下多输出分类。

    这种类型的任务全称应该叫多输出—多类别分类,简单来说是多标签分类的泛化,其标签也可以是多种类别的。用一个例子来解释:我们现在需要针对含有噪声的图片进行去噪处理,就用我们的手写图像为例。我们构造的这个系统输出就是多个标签(一个像素点一个标签),每个标签可以存在多个值(0-255)。这个任务比起分类任务来说,更像回归任务,但是我们需要注意的是分类和回归的界限有时很模糊,我们其实不必要一定要分的很细,适合我们的目的就好,灵活掌握。

    来,最后上一波代码,说明上述的这个问题。当然,我们首先要添加噪声,然后编写图像显示的函数,最后对比一下输出就好:

    第四章的心得总结就到这里,我通过第四章的学习,将之前对召回率、精度、ROC曲线、F1分数等一系列概念有了更深的了解,希望通过自己的心得总结,能够帮助更多的人,谢谢。

    源码获取:

     1、我的Github:https://github.com/niufuquan1/MyStudy_For_sklearn_tensorflow/tree/master/ML_MNIST

    2、百度云:https://pan.baidu.com/s/1sX2ulOE7xgjJTzfVmA0N0Q (rxfv)

    前进,前进,不择手段地前进!
  • 相关阅读:
    MySQL四种分区类型
    CentOS下升级MySQL 5.0.* 到5.5
    CentOS5.5使用yum来安装LAMP
    mysql-bin 常见操作
    引爆你的Javascript代码进化
    python读写excel的简单方法demo
    python时间戳数字转为字符串格式表达
    Djang——CSRF verification failed. Request aborted
    Apache部署django
    Qt设置windows系统时间
  • 原文地址:https://www.cnblogs.com/nfuquan/p/10595243.html
Copyright © 2020-2023  润新知