• 交叉验证_自动获取模型最优超参数


    package Spark_MLlib
    
    import org.apache.spark.ml.{Pipeline, PipelineModel}
    import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
    import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.ml.linalg.{Vector, Vectors}
    import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
    
    /**
      * Row schema for the tuning + model-selection example: one parsed input line.
      *
      * Kept at top level (outside any method) so Spark can derive an encoder for it
      * when the RDD of these rows is converted with `toDF()`.
      *
      * @param features the first four numeric columns of the line, as a dense vector
      * @param label    the string class label taken from the fifth column (e.g. "soyo1")
      */
    case class schema_source(
        features: Vector,
        label: String
    )
    /**
      * Hyper-parameter tuning + model selection for logistic regression using
      * k-fold cross-validation (Spark ML).
      *
      * Reads comma-separated lines of four numeric features followed by a string
      * label, builds a pipeline (label indexing -> feature indexing -> logistic
      * regression -> label decoding), grid-searches `regParam` x `elasticNetParam`
      * with a CrossValidator, evaluates the winning pipeline on a held-out split and
      * prints the parameters of the best model.
      */
    object 交叉验证_调参_逻辑回归 {

      /** Local session with 2 worker threads; stopped when `main` finishes. */
      val spark: SparkSession = SparkSession.builder().master("local[2]").getOrCreate()
      import spark.implicits._

      /** Input file used when no path is passed as the first program argument. */
      private val DefaultDataPath = "file:///home/soyo/桌面/spark编程测试数据/soyo.txt"

      /** Fixed seed so the train/test split and the CV folds are reproducible between runs. */
      private val Seed = 42L

      /**
        * Entry point.
        *
        * @param args optional; `args(0)` overrides the input path (defaults to DefaultDataPath)
        */
      def main(args: Array[String]): Unit = {
        try {
          run(args.headOption.getOrElse(DefaultDataPath))
        } finally {
          // Release the local Spark resources even if loading, training or evaluation throws.
          spark.stop()
        }
      }

      /** Runs the whole load -> tune -> evaluate -> report flow on the file at `path`. */
      private def run(path: String): Unit = {
        val data = loadData(path)
        data.show()

        // Both indexers are fitted on the FULL dataset so every label / feature value is
        // known to them. They enter the pipeline as already-fitted models, so the
        // cross-validator does not refit them per fold.
        val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(data)
        val featuresIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").fit(data)
        val Array(trainData, testData) = data.randomSplit(Array(0.7, 0.3), Seed)

        val lr = new LogisticRegression()
          .setLabelCol("indexedLabel")
          .setFeaturesCol("indexedFeatures")
          .setMaxIter(50)
        // Maps the numeric prediction back to the original string label.
        val labelConverter = new IndexToString()
          .setInputCol("prediction")
          .setOutputCol("predictedLabel")
          .setLabels(labelIndexer.labels)
        labelIndexer.labels.foreach(println)

        // Machine-learning workflow.
        val lrPipeline = new Pipeline().setStages(Array(labelIndexer, featuresIndexer, lr, labelConverter))

        // Metric the cross-validator maximises when choosing parameters. "f1" is the
        // evaluator's default; it is set explicitly so it is not mistaken for accuracy.
        val cvEvaluator = new MulticlassClassificationEvaluator()
          .setLabelCol("indexedLabel")
          .setPredictionCol("prediction")
          .setMetricName("f1")
        // 3 x 2 parameter grid.
        val paramGrid = new ParamGridBuilder()
          .addGrid(lr.regParam, Array(0.01, 0.3, 0.8))
          .addGrid(lr.elasticNetParam, Array(0.3, 0.9))
          .build()
        // Cross-validation: estimator, evaluator, parameter grid and number of folds.
        val cv = new CrossValidator()
          .setEstimator(lrPipeline)
          .setEvaluator(cvEvaluator)
          .setEstimatorParamMaps(paramGrid)
          .setNumFolds(3)
          .setSeed(Seed)

        // Train.
        val cvModel = cv.fit(trainData)

        // Test on the held-out split.
        val lrPrediction = cvModel.transform(testData)
        lrPrediction.show()
        // The evaluator's default metric is "f1", NOT accuracy; request accuracy explicitly
        // so the values printed as accuracy / error rate really are accuracy and 1 - accuracy.
        val accuracyEvaluator = new MulticlassClassificationEvaluator()
          .setLabelCol("indexedLabel")
          .setPredictionCol("prediction")
          .setMetricName("accuracy")
        val lrAccuracy = accuracyEvaluator.evaluate(lrPrediction)
        println("准确率为: " + lrAccuracy)
        println("错误率为: " + (1 - lrAccuracy))

        printBestModel(cvModel)
      }

      /**
        * Parses `path` into a DataFrame of schema_source rows.
        *
        * Each non-blank line must be `f0,f1,f2,f3,label`. The four features become a dense
        * vector and the label is kept as a trimmed string (trimming also strips a stray
        * '\r' from CRLF files). Blank lines, such as a trailing newline at end of file, are
        * skipped; any other malformed line still fails loudly.
        */
      private def loadData(path: String): org.apache.spark.sql.DataFrame =
        spark.sparkContext.textFile(path)
          .filter(_.trim.nonEmpty)
          .map(_.split(","))
          .map(x => schema_source(
            Vectors.dense(x(0).toDouble, x(1).toDouble, x(2).toDouble, x(3).toDouble),
            x(4).trim))
          .toDF()

      /** Prints the coefficients, intercepts and chosen hyper-parameters of the best model. */
      private def printBestModel(cvModel: org.apache.spark.ml.tuning.CrossValidatorModel): Unit = {
        // Safe cast: the estimator handed to the cross-validator is a Pipeline.
        val bestModel = cvModel.bestModel.asInstanceOf[PipelineModel]
        // Locate the logistic-regression stage by type rather than by a hard-coded position.
        val lrModel = bestModel.stages
          .collectFirst { case m: LogisticRegressionModel => m }
          .getOrElse(throw new IllegalStateException("best pipeline has no LogisticRegressionModel stage"))
        // NOTE(review): the printed labels say "二项" (binomial). With family=auto and more
        // than two classes, Spark fits a multinomial model, so coefficientMatrix has one
        // row per class.
        println("二项逻辑回归模型系数矩阵: " + lrModel.coefficientMatrix)
        println("二项逻辑回归模型的截距向量: " + lrModel.interceptVector)
        println("类的数量(标签可以使用的值): " + lrModel.numClasses)
        println("模型所接受的特征的数量: " + lrModel.numFeatures)

        println("所有参数的设置为: " + lrModel.explainParams())
        println("最优的regParam的值为: " + lrModel.explainParam(lrModel.regParam))
        println("最优的elasticNetParam的值为: " + lrModel.explainParam(lrModel.elasticNetParam))
      }
    }
    运行结果:
    +-----------------+-----+
    |         features|label|
    +-----------------+-----+
    |[5.1,3.5,1.4,0.2]|soyo1|
    |[4.9,3.0,1.4,0.2]|soyo1|
    |[4.7,3.2,1.3,0.2]|soyo1|
    |[4.6,3.1,1.5,0.2]|soyo1|
    |[5.0,3.6,1.4,0.2]|soyo1|
    |[5.4,3.9,1.7,0.4]|soyo1|
    |[4.6,3.4,1.4,0.3]|soyo1|
    |[5.0,3.4,1.5,0.2]|soyo1|
    |[4.4,2.9,1.4,0.2]|soyo1|
    |[4.9,3.1,1.5,0.1]|soyo1|
    |[5.4,3.7,1.5,0.2]|soyo1|
    |[4.8,3.4,1.6,0.2]|soyo1|
    |[4.8,3.0,1.4,0.1]|soyo1|
    |[4.3,3.0,1.1,0.1]|soyo1|
    |[5.8,4.0,1.2,0.2]|soyo1|
    |[5.7,4.4,1.5,0.4]|soyo1|
    |[5.4,3.9,1.3,0.4]|soyo1|
    |[5.1,3.5,1.4,0.3]|soyo1|
    |[5.7,3.8,1.7,0.3]|soyo1|
    |[5.1,3.8,1.5,0.3]|soyo1|
    +-----------------+-----+
    only showing top 20 rows
    
    soyo2
    soyo1
    soyo3
    +-----------------+-----+------------+-----------------+--------------------+--------------------+----------+--------------+
    |         features|label|indexedLabel|  indexedFeatures|       rawPrediction|         probability|prediction|predictedLabel|
    +-----------------+-----+------------+-----------------+--------------------+--------------------+----------+--------------+
    |[4.3,3.0,1.1,0.1]|soyo1|         1.0|[4.3,3.0,1.1,0.1]|[-0.2949197997435...|[0.00821657808181...|       1.0|         soyo1|
    |[4.4,2.9,1.4,0.2]|soyo1|         1.0|[4.4,2.9,1.4,0.2]|[-0.1436502505351...|[0.02310764702310...|       1.0|         soyo1|
    |[4.6,3.1,1.5,0.2]|soyo1|         1.0|[4.6,3.1,1.5,0.2]|[-0.1980725396328...|[0.01584026165726...|       1.0|         soyo1|
    |[4.8,3.0,1.4,0.1]|soyo1|         1.0|[4.8,3.0,1.4,0.1]|[-0.0360182992158...|[0.01909506488946...|       1.0|         soyo1|
    |[4.8,3.1,1.6,0.2]|soyo1|         1.0|[4.8,3.1,1.6,0.2]|[-0.0963956817735...|[0.02165865158723...|       1.0|         soyo1|
    |[4.8,3.4,1.6,0.2]|soyo1|         1.0|[4.8,3.4,1.6,0.2]|[-0.3305444022091...|[0.00764403083532...|       1.0|         soyo1|
    |[4.9,2.4,3.3,1.0]|soyo2|         0.0|[4.9,2.4,3.3,1.0]|[0.64687664475266...|[0.83588965920895...|       0.0|         soyo2|
    |[4.9,3.0,1.4,0.2]|soyo1|         1.0|[4.9,3.0,1.4,0.2]|[0.00894554123863...|[0.02696343238302...|       1.0|         soyo1|
    |[5.0,3.5,1.6,0.6]|soyo1|         1.0|[5.0,3.5,1.6,0.6]|[-0.3209967599706...|[0.01781564148264...|       1.0|         soyo1|
    |[5.0,3.6,1.4,0.2]|soyo1|         1.0|[5.0,3.6,1.4,0.2]|[-0.4132228265822...|[0.00370148550004...|       1.0|         soyo1|
    |[5.1,3.7,1.5,0.4]|soyo1|         1.0|[5.1,3.7,1.5,0.4]|[-0.4380550804437...|[0.00533390253840...|       1.0|         soyo1|
    |[5.1,3.8,1.9,0.4]|soyo1|         1.0|[5.1,3.8,1.9,0.4]|[-0.4784298068885...|[0.00593236888116...|       1.0|         soyo1|
    |[5.2,2.7,3.9,1.4]|soyo2|         0.0|[5.2,2.7,3.9,1.4]|[0.60296648363520...|[0.65499655703255...|       0.0|         soyo2|
    |[5.2,3.5,1.5,0.2]|soyo1|         1.0|[5.2,3.5,1.5,0.2]|[-0.2334963952443...|[0.00721300202565...|       1.0|         soyo1|
    |[5.3,3.7,1.5,0.2]|soyo1|         1.0|[5.3,3.7,1.5,0.2]|[-0.3434664691509...|[0.00396451436269...|       1.0|         soyo1|
    |[5.4,3.4,1.5,0.4]|soyo1|         1.0|[5.4,3.4,1.5,0.4]|[-0.0655191408567...|[0.02050202848213...|       1.0|         soyo1|
    |[5.4,3.4,1.7,0.2]|soyo1|         1.0|[5.4,3.4,1.7,0.2]|[-0.0443512521479...|[0.01568504280438...|       1.0|         soyo1|
    |[5.4,3.9,1.3,0.4]|soyo1|         1.0|[5.4,3.9,1.3,0.4]|[-0.4746044317663...|[0.00285607924154...|       1.0|         soyo1|
    |[5.4,3.9,1.7,0.4]|soyo1|         1.0|[5.4,3.9,1.7,0.4]|[-0.4369295847326...|[0.00451151133277...|       1.0|         soyo1|
    |[5.5,2.3,4.0,1.3]|soyo2|         0.0|[5.5,2.3,4.0,1.3]|[1.06413594105520...|[0.51327715648015...|       0.0|         soyo2|
    +-----------------+-----+------------+-----------------+--------------------+--------------------+----------+--------------+
    only showing top 20 rows
    
    准确率为: 0.9418343292582645
    错误率为: 0.05816567074173551
    二项逻辑回归模型系数矩阵: 0.4612907305046201    -0.7804957347855317  0.09418711758439907  -0.011652325959556013  
    -0.559055378870932    2.7385209747134933   -1.052922922424876   -2.5223769474140303    
    -0.07629895224519458  -3.6867236615320547  1.0014498171011217   4.581938360185545      
    二项逻辑回归模型的截距向量: [-0.039423333303658874,0.0972586768296292,-0.05783534352597033]
    类的数量(标签可以使用的值): 3
    模型所接受的特征的数量: 4
    所有参数的设置为: aggregationDepth: suggested depth for treeAggregate (>= 2) (default: 2)
    elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty (default: 0.0, current: 0.9)
    family: The name of family which is a description of the label distribution to be used in the model. Supported options: auto, binomial, multinomial. (default: auto)
    featuresCol: features column name (default: features, current: indexedFeatures)
    fitIntercept: whether to fit an intercept term (default: true)
    labelCol: label column name (default: label, current: indexedLabel)
    lowerBoundsOnCoefficients: The lower bounds on coefficients if fitting under bound constrained optimization. (undefined)
    lowerBoundsOnIntercepts: The lower bounds on intercepts if fitting under bound constrained optimization. (undefined)
    maxIter: maximum number of iterations (>= 0) (default: 100, current: 50)
    predictionCol: prediction column name (default: prediction)
    probabilityCol: Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities (default: probability)
    rawPredictionCol: raw prediction (a.k.a. confidence) column name (default: rawPrediction)
    regParam: regularization parameter (>= 0) (default: 0.0, current: 0.01)
    standardization: whether to standardize the training features before fitting the model (default: true)
    threshold: threshold in binary classification prediction, in range [0, 1] (default: 0.5)
    thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold (undefined)
    tol: the convergence tolerance for iterative algorithms (>= 0) (default: 1.0E-6)
    upperBoundsOnCoefficients: The upper bounds on coefficients if fitting under bound constrained optimization. (undefined)
    upperBoundsOnIntercepts: The upper bounds on intercepts if fitting under bound constrained optimization. (undefined)
    weightCol: weight column name. If this is not set or empty, we treat all instance weights as 1.0 (undefined)

    最优的regParam的值为: regParam: regularization parameter (>= 0) (default: 0.0, current: 0.01)
    最优的elasticNetParam的值为: elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty (default: 0.0, current: 0.9)
  • 相关阅读:
    SpringBoot第五篇:整合Mybatis
    SpringBoot第四篇:整合JDBCTemplate
    SpringBoot第三篇:配置文件详解二
    分享一篇去年的项目总结
    Oracle生成多表触发器sql
    Oracle 设置用户密码永不过期
    Oracle建表提示SQL 错误: ORA-00904: : 标识符无效
    MySql数据备份
    ETL全量多表同步简述
    ETL全量单表同步简述
  • 原文地址:https://www.cnblogs.com/soyo/p/7826771.html
Copyright © 2020-2023  润新知