• 机器学习--鸢尾花分类任务


    1. 使用sklearn完成鸢尾花分类任务。

      相关知识

      为了完成本关任务,你需要掌握如何使用sklearn提供的DecisionTreeClassifier

      数据简介


      鸢尾花数据集是一类多重变量分析的数据集。通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(SetosaVersicolourVirginica)三个种类中的哪一类(其中分别用012代替)。

      数据集中部分数据与标签如下图所示:



      DecisionTreeClassifier

      DecisionTreeClassifier的构造函数中有两个常用的参数可以设置:

      • criterion:划分节点时用到的指标。有gini基尼系数),entropy(信息增益)。若不设置,默认为gini
      • max_depth:决策树的最大深度,如果发现模型已经出现过拟合,可以尝试将该参数调小。若不设置,默认为None

      sklearn中其他分类器一样,DecisionTreeClassifier类中的fit函数用于训练模型,fit函数有两个向量输入:

      • X:大小为[样本数量,特征数量]ndarray,存放训练样本;

      • Y:值为整型,大小为[样本数量]ndarray,存放训练样本的分类标签。

      DecisionTreeClassifier类中的predict函数用于预测,返回预测标签,predict函数有一个向量输入:

      • X:大小为[样本数量,特征数量]ndarray,存放预测样本。

      DecisionTreeClassifier的使用代码如下:

      1. from sklearn.tree import DecisionTreeClassifier
      2. clf = tree.DecisionTreeClassifier()
      3. clf.fit(X_train, Y_train)
      4. result = clf.predict(X_test)

      编程要求

      补充python代码,实现鸢尾花数据的分类任务,其中训练集数据保存在./step7/train_data.csv中,训练集标签保存在。./step7/train_label.csv中,测试集数据保存在。./step7/test_data.csv中。请将对测试集的预测结果保存至。./step7/predict.csv中。这些csv文件可以使用pandas读取与写入。

      注意:当使用pandas读取完csv文件后,请将读取到的DataFrame转换成ndarray类型。这样才能正常的使用fitpredict

      示例代码:

      1. import pandas as pd
      2. # as_matrix()可以将DataFrame转换成ndarray
      3. # 此时train_df的类型为ndarray而不是DataFrame
      4. train_df = pd.read_csv('train_data.csv').as_matrix()

      数据文件格式如下图所示:


      标签文件格式如下图所示:


      PS:predict.csv文件的格式必须与标签文件格式一致。

      测试说明

      只需将结果保存至./step7/predict.csv即可,程序内部会检测您的代码,预测准确率高于0.95视为过关。

    代码一


    import pandas as pd
    from sklearn.tree import DecisionTreeClassifier
    train_df = pd.read_csv('./step7/train_data.csv').as_matrix()
    train_label = pd.read_csv('./step7/train_label.csv').as_matrix()
    test_df = pd.read_csv('./step7/test_data.csv').as_matrix()
    dt = DecisionTreeClassifier()
    dt.fit(train_df, train_label)
    result = dt.predict(test_df)
    result = pd.DataFrame({'target':result})
    result.to_csv('./step7/predict.csv', index=False)
     

    代码二

    #********* Begin *********#
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    # from sklearn.tree import DecisionTreeClassifier, export_graphviz
    import numpy as np
    import pandas as pd

    #获取训练数据
    train_data = pd.read_csv('./step7/train_data.csv')
    #获取训练标签
    train_label = pd.read_csv('./step7/train_label.csv')
    train_label = train_label['target']
    #获取测试数据
    test_data = pd.read_csv('./step7/test_data.csv').as_matrix()
    # train_df = pd.read_csv('train_data.csv').as_matrix()
    #训练模型
    # as_matrix()可以将DataFrame转换成ndarray
    # 此时train_df的类型为ndarray而不是DataFrame

    clf = DecisionTreeClassifier()
    clf.fit(train_data,train_label)
    #获取预测标签
    predict = clf.predict(test_data)
    #将预测标签写入csv
    df = pd.DataFrame({'target':predict})
    df.to_csv("./step7/predict.csv",index=False)

    #********* End *********#
  • 相关阅读:
    Redis --- 客户端 --- Another Redis Desktop Manager
    Docker --- 记录安装与使用过程中遇到的问题
    Docker安装教程
    Python --- pip --- No module named 'pip'异常问题
    天气接口测试用例生成报告
    jsonpath使用
    python小知识,字典
    python小知识,列表推导式
    python小知识,sort和serted的区别
    如何查看app启动的activity
  • 原文地址:https://www.cnblogs.com/xueshadouhui/p/12627580.html
Copyright © 2020-2023  润新知