-
使用
sklearn
完成鸢尾花分类任务。相关知识
为了完成本关任务,你需要掌握如何使用
sklearn
提供的DecisionTreeClassifier
。数据简介
鸢尾花数据集是一类多重变量分析的数据集。通过花萼长度,花萼宽度,花瓣长度,花瓣宽度
4
个属性预测鸢尾花卉属于(Setosa
,Versicolour
,Virginica
)三个种类中的哪一类(其中分别用0
,1
,2
代替)。数据集中部分数据与标签如下图所示:
DecisionTreeClassifier
DecisionTreeClassifier
的构造函数中有两个常用的参数可以设置:criterion
:划分节点时用到的指标。有gini
(基尼系数),entropy
(信息增益)。若不设置,默认为gini
max_depth
:决策树的最大深度,如果发现模型已经出现过拟合,可以尝试将该参数调小。若不设置,默认为None
和
sklearn
中其他分类器一样,DecisionTreeClassifier
类中的fit
函数用于训练模型,fit
函数有两个向量输入:-
X
:大小为[样本数量,特征数量]
的ndarray
,存放训练样本; -
Y
:值为整型,大小为[样本数量]
的ndarray
,存放训练样本的分类标签。
DecisionTreeClassifier
类中的predict
函数用于预测,返回预测标签,predict
函数有一个向量输入:X
:大小为[样本数量,特征数量]
的ndarray
,存放预测样本。
DecisionTreeClassifier
的使用代码如下:from sklearn.tree import DecisionTreeClassifier
clf = tree.DecisionTreeClassifier()
clf.fit(X_train, Y_train)
result = clf.predict(X_test)
编程要求
补充
python
代码,实现鸢尾花数据的分类任务,其中训练集数据保存在./step7/train_data.csv
中,训练集标签保存在。./step7/train_label.csv
中,测试集数据保存在。./step7/test_data.csv
中。请将对测试集的预测结果保存至。./step7/predict.csv
中。这些csv
文件可以使用pandas
读取与写入。注意:当使用
pandas
读取完csv
文件后,请将读取到的DataFrame
转换成ndarray
类型。这样才能正常的使用fit
和predict
。示例代码:
import pandas as pd
# as_matrix()可以将DataFrame转换成ndarray
# 此时train_df的类型为ndarray而不是DataFrame
train_df = pd.read_csv('train_data.csv').as_matrix()
数据文件格式如下图所示:
标签文件格式如下图所示:
PS:
predict.csv
文件的格式必须与标签文件格式一致。测试说明
只需将结果保存至
./step7/predict.csv
即可,程序内部会检测您的代码,预测准确率高于0.95
视为过关。
代码一
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
train_label = pd.read_csv('./step7/train_label.csv').as_matrix()
test_df = pd.read_csv('./step7/test_data.csv').as_matrix()
dt.fit(train_df, train_label)
result = dt.predict(test_df)
result.to_csv('./step7/predict.csv', index=False)
代码二