-
需求：使用 Spark ML 完成一个分类问题的建模流程（以 iris 鸢尾花数据集为例，训练决策树分类器并评估准确率）
-
开发步骤:
- 1-准备 SparkSession 环境
- 2-准备数据（iris.csv）
- 3-读取数据并进行解析
- 4-查看数据的基本信息（schema、样例数据）
- 5-特征工程
- 6-准备算法
- 7-模型训练
- 8-模型预测
- 9-模型评估
- 10-模型保存
- 11-使用模型对新数据进行预测
-
代码模板:
import org.apache.spark.SparkConf
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature._
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
/**
 * DESC: Template code for a classification problem (decision tree trained on the iris CSV).
 *
 * Complete data processing and modeling steps:
 *  - 1. Prepare the SparkSession environment
 *  - 2. Prepare the data
 *  - 3. Read and parse the data
 *  - 4. Inspect basic information about the data
 *  - 5. Feature engineering
 *  - 6. Prepare the algorithm
 *  - 7. Train the model
 *  - 8. Predict with the model
 *  - 9. Evaluate the model
 *  - 10. Save the model
 *  - 11. Predict on new data
 */
object ClassficationModelTest {

  /**
   * Default location of the iris CSV, used when no path is passed as the first program argument.
   *
   * Forward slashes keep the literal free of backslash escape sequences (`\B`, `\r`, ... are
   * either illegal or silently rewritten inside a Scala string); Windows accepts `/` as a separator.
   * Left as a public `var` so existing code that reassigns it before calling `main` keeps working.
   */
  var datapath = "D:/BigData/Workspace/SparkMachineLearningTest/SparkMllib_BigData32/src/main/resources/iris.csv"

  /**
   * Trains a decision-tree classifier on the iris CSV and prints train/test accuracy.
   *
   * @param args optional; when present, `args(0)` is used as the CSV path instead of [[datapath]]
   */
  def main(args: Array[String]): Unit = {
    // 1. Prepare the SparkSession environment (local mode, all cores)
    val conf: SparkConf = new SparkConf().setAppName("ClassficationModelTest").setMaster("local[*]")
    val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
    spark.sparkContext.setLogLevel("WARN")
    try {
      run(spark, args.headOption.getOrElse(datapath))
    } finally {
      // Release the Spark context even if one of the steps throws.
      spark.stop()
    }
  }

  /** Runs steps 2-9 of the template against the CSV located at `dataPath`. */
  private def run(spark: SparkSession, dataPath: String): Unit = {
    // Column names are defined once so a misspelling cannot silently diverge between usages.
    val sepalLengthFeature = "sepal_length"
    val sepalWidthFeature = "sepal_width"
    val petalLengthFeature = "petal_length"
    val petalWidthFeature = "petal_width"
    val classLabel = "class"
    val indexedLabelCol = "classlabel"
    val assembledFeaturesCol = "features"
    val indexedFeaturesCol = "vecindexFeatures"
    val predictionCol = "prces"

    // 2-3. Read the CSV (first row is the header, column types inferred) and show a sample
    val irisDF: DataFrame = spark.read.format("csv")
      .option("header", true)
      .option("inferSchema", true)
      .option("sep", ",")
      .load(dataPath)
    irisDF.show(10, false)
    // +------------+-----------+------------+-----------+-----------+
    // |sepal_length|sepal_width|petal_length|petal_width|class      |
    // +------------+-----------+------------+-----------+-----------+
    // |5.1         |3.5        |1.4         |0.2        |Iris-setosa|
    // |4.9         |3.0        |1.4         |0.2        |Iris-setosa|
    // |4.7         |3.2        |1.3         |0.2        |Iris-setosa|
    // |4.6         |3.1        |1.5         |0.2        |Iris-setosa|

    // 4. Inspect basic information about the data
    irisDF.printSchema()
    // root
    //  |-- sepal_length: double (nullable = true)
    //  |-- sepal_width: double (nullable = true)
    //  |-- petal_length: double (nullable = true)
    //  |-- petal_width: double (nullable = true)
    //  |-- class: string (nullable = true)

    // 5. Feature engineering
    // 5-1. Encode the string class label as a numeric label index
    val stringIndexerModel: StringIndexerModel = new StringIndexer()
      .setInputCol(classLabel)
      .setOutputCol(indexedLabelCol)
      .fit(irisDF)
    val indexDF: DataFrame = stringIndexerModel.transform(irisDF)

    // 5-2. Assemble the four separate numeric columns into a single feature vector
    val vecDF: DataFrame = new VectorAssembler()
      .setInputCols(Array(sepalLengthFeature, sepalWidthFeature, petalLengthFeature, petalWidthFeature))
      .setOutputCol(assembledFeaturesCol)
      .transform(indexDF)

    // 5-3. Index vector entries with at most 20 distinct values as categorical features,
    // so the decision tree can split on them as categories.
    // NOTE(review): this indexer is fitted on the full dataset before the train/test split;
    // for a strictly held-out evaluation, fit it on the training split only.
    val vectorIndexerModel: VectorIndexerModel = new VectorIndexer()
      .setInputCol(assembledFeaturesCol)
      .setOutputCol(indexedFeaturesCol)
      .setMaxCategories(20)
      .fit(vecDF)
    val vecindexerDF: DataFrame = vectorIndexerModel.transform(vecDF)
    vecindexerDF.show(10, false)

    // 6. Prepare the algorithm: a decision tree of max depth 5 using Gini impurity
    val classifier: DecisionTreeClassifier = new DecisionTreeClassifier()
      .setLabelCol(indexedLabelCol)
      .setPredictionCol(predictionCol)
      .setFeaturesCol(indexedFeaturesCol)
      .setMaxDepth(5)
      .setImpurity("gini")
    // 80/20 train/test split with a fixed seed so runs are reproducible
    val Array(trainingSet, testSet): Array[Dataset[Row]] =
      vecindexerDF.randomSplit(Array(0.8, 0.2), seed = 1234L)

    // 7. Train the model
    val model: DecisionTreeClassificationModel = classifier.fit(trainingSet)

    // 8. Predict on both splits
    val yPredTrain: DataFrame = model.transform(trainingSet)
    val yPredTest: DataFrame = model.transform(testSet)
    yPredTrain.show(10, false)

    // 9. Evaluate; other supported metric names include f1, weightedPrecision, weightedRecall
    val evaluator: MulticlassClassificationEvaluator = new MulticlassClassificationEvaluator()
      .setMetricName("accuracy")
      .setPredictionCol(predictionCol)
      .setLabelCol(indexedLabelCol)
    val accTest: Double = evaluator.evaluate(yPredTest)
    val accTrain: Double = evaluator.evaluate(yPredTrain)
    // Interpolate instead of println(a, b), which auto-tuples its arguments into one Tuple2.
    println(s"acc in trainset score is: $accTrain")
    println(s"acc in testset score is: $accTest")
    // Values observed in an earlier run:
    // acc in trainset score is: 0.9920634920634921
    // acc in testset score is: 0.9583333333333334

    // 10-11. Saving / reloading the model (disabled). write.overwrite() keeps reruns from failing
    // on an already-existing directory, and the reloaded model is the one used for prediction:
    // val modelPath = "D:/BigData/Workspace/SparkMachineLearningTest/SparkMllib_BigData32/src/main/resources/model1"
    // model.write.overwrite().save(modelPath)
    // val loadedModel = DecisionTreeClassificationModel.load(modelPath)
    // loadedModel.transform(testSet).show(10, false)
    // NOTE: only the tree itself is saved; raw new data must first go through the same fitted
    // VectorAssembler and vectorIndexerModel to produce the indexed feature column it expects.
  }
}