• scikit-learn入门学习记录


    一加载示例数据集

    from sklearn import datasets
    
    iris = datasets.load_iris()
    digits = datasets.load_digits()

    数据集是一个类似字典的对象,它保存有关数据的所有数据和一些元数据。该数据存储在.data成员中,它是一个数组

    数字数据集存放在digits.data,数据如下,里面包含很多数字数据集的数据,一个列表即一个数字所有数据

    [[  0.   0.   5. ...,   0.   0.   0.]
     [  0.   0.   0. ...,  10.   0.   0.]
     [  0.   0.   0. ...,  16.   9.   0.]
     ..., 
     [  0.   0.   1. ...,   6.   0.   0.]
     [  0.   0.   2. ...,  12.   0.   0.]
     [  0.   0.  10. ...,  12.   1.   0.]]

    digits.target给出数字数据集的真实数据,即我们正在尝试学习的每个数字图像对应的数字,数据如下

    [0 1 2 ..., 8 9 8]

    digits.image[0],其实和digits.data[0]数据一样,只是转换成二维的矩阵,数据如下

    [[  0.   0.   5.  13.   9.   1.   0.   0.]
     [  0.   0.  13.  15.  10.  15.   5.   0.]
     [  0.   3.  15.   2.   0.  11.   8.   0.]
     [  0.   4.  12.   0.   0.   8.   8.   0.]
     [  0.   5.   8.   0.   0.   9.   8.   0.]
     [  0.   4.  11.   0.   1.  12.   7.   0.]
     [  0.   2.  14.   5.  10.  12.   0.   0.]
     [  0.   0.   6.  13.  10.   0.   0.   0.]]

    digits.data[0]和digits.image[0]对比

    [  0.   0.   5.  13.   9.   1.   0.   0.   0.   0.  13.  15.  10.  15.   5.
       0.   0.   3.  15.   2.   0.  11.   8.   0.   0.   4.  12.   0.   0.   8.
       8.   0.   0.   5.   8.   0.   0.   9.   8.   0.   0.   4.  11.   0.   1.
      12.   7.   0.   0.   2.  14.   5.  10.  12.   0.   0.   0.   0.   6.  13.
      10.   0.   0.   0.]
    [[  0.   0.   5.  13.   9.   1.   0.   0.]
     [  0.   0.  13.  15.  10.  15.   5.   0.]
     [  0.   3.  15.   2.   0.  11.   8.   0.]
     [  0.   4.  12.   0.   0.   8.   8.   0.]
     [  0.   5.   8.   0.   0.   9.   8.   0.]
     [  0.   4.  11.   0.   1.  12.   7.   0.]
     [  0.   2.  14.   5.  10.  12.   0.   0.]
     [  0.   0.   6.  13.  10.   0.   0.   0.]]

    在数字数据集的情况下,任务是给出图像来预测其表示的数字。我们给出了10个可能类(数字从零到九)中的每一个的样本,我们在其上拟合一个 估计器,以便能够预测 看不见的样本所属的类。

    在scikit-learn,分类的估计是实现方法的Python对象和。fit(X, y)predict(T)

    估计器的一个例子是sklearn.svm.SVC实现支持向量分类的类。估计器的构造函数作为模型的参数作为参数,但目前我们将把估计器视为黑盒子

    from sklearn import svm
    
    clf = svm.SVC(gamma=0.001, C=100.)

    在这个例子中,我们设置gamma手动的值。通过使用诸如网格搜索交叉验证等工具,可以自动找到参数的良好值。

    我们称之为我们的估计器实例clf,因为它是一个分类器。它现在必须适应模型,也就是说,它必须从模型中学习。这是通过将我们的训练集传递给该fit方法来完成的。作为一个训练集,让我们使用除最后一个数据集的所有图像。我们用[:-1]Python语法选择这个训练集,它产生一个包含除最后一个条目之外的所有数组的新数组digits.data

    clf.fit(digits.data[:-1], digits.target[:-1])
    现在,您可以预测新值,特别是可以向分类器询问digits数据集中最后一个图像的数字是什么,我们还没有用来对分类器进行训练:
    print(clf.predict(digits.data[-1:]))

    总结一下

    其实就是创建一个svm类的实例

    使用fit来将训练集传递给该实例,传入两个参数,数据以及真实值

    最后使用predict来对数据进行预估

    下面给个完整的实例

    import matplotlib.pyplot as plt
    from sklearn import datasets, svm, metrics
    
    digits = datasets.load_digits()
    images_and_labels = list(zip(digits.images, digits.target))
    for index, (image, label) in enumerate(images_and_labels[:4]):
        plt.subplot(2, 4, index+1)
        plt.axis('off')
        plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
        plt.title('Training: %s' %label)
    
    n_samples = len(digits.images)
    print('before:', digits.images)
    data = digits.images.reshape((n_samples, -1))
    
    classifier = svm.SVC(gamma=0.001)
    classifier.fit(data[:n_samples//2], digits.target[:n_samples//2])
    expected = digits.target[n_samples//2:]
    predicted = classifier.predict(data[n_samples//2:])
    
    print('Classification report for classifiler %s:
    %s
    ' %(classifier, metrics.classification_report(expected, predicted)))
    print('Confusion matrix:
    %s' %metrics.confusion_matrix(expected, predicted))
    
    images_and_predictions = list(zip(digits.images[n_samples//2:], predicted))
    for index, (image, prediction) in enumerate(images_and_predictions[:4]):
        plt.subplot(2,4, index+5)
        plt.axis('off')
        plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
        plt.title('Prediction: %i' %prediction)
    plt.show()

    参考自http://cwiki.apachecn.org/pages/viewpage.action?pageId=10813673

    找到一个不错的繁体中文文档,解释的比较详细

    https://machine-learning-python.kspax.io/Classification/ex1_Recognizing_hand-written_digits.html

  • 相关阅读:
    分布式事务解决方案
    数据库和缓存双写一致性解析
    RabbitMQ 分区脑裂处理策略
    RabbitMQ实现延迟队列
    RabbitMQ高可用原理
    PyTorch Lightning工具学习
    【数学知识拾贝】模式识别所需要的线性代数知识总结
    【深度强化学习】1. 基础部分
    给内容打标签
    前端性能优化有哪些点
  • 原文地址:https://www.cnblogs.com/lgh344902118/p/7889280.html
Copyright © 2020-2023  润新知