• 我的spark python 决策树实例


    """PySpark MLlib decision-tree classification example.

    Trains a binary DecisionTree classifier on a toy one-feature dataset,
    evaluates it on a held-out split (AUC, PR, test error), and round-trips
    the model through save/load.
    """
    from numpy import array
    from pyspark.mllib.regression import LabeledPoint
    from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
    from pyspark import SparkContext
    from pyspark.mllib.evaluation import BinaryClassificationMetrics

    sc = SparkContext(appName="PythonDecisionTreeClassificationExample")

    # Toy dataset: non-positive feature values are labeled 0.0,
    # values >= 1.0 are labeled 1.0.
    data = [
        LabeledPoint(0.0, [0.0]),
        LabeledPoint(1.0, [1.0]),
        LabeledPoint(0.0, [-2.0]),
        LabeledPoint(0.0, [-1.0]),
        LabeledPoint(0.0, [-3.0]),
        LabeledPoint(1.0, [4.0]),
        LabeledPoint(1.0, [4.5]),
        LabeledPoint(1.0, [4.9]),
        LabeledPoint(1.0, [3.0]),
    ]
    all_data = sc.parallelize(data)
    # Fixed seed so the 80/20 split -- and therefore the reported metrics --
    # are reproducible across runs.
    trainingData, testData = all_data.randomSplit([0.8, 0.2], seed=42)

    model = DecisionTree.trainClassifier(trainingData, numClasses=2,
                                         categoricalFeaturesInfo={},
                                         impurity='gini', maxDepth=5, maxBins=32)
    print(model)
    print(model.toDebugString())

    # Single-vector and RDD prediction examples (print the results instead of
    # discarding them as the original REPL transcript did).
    print(model.predict(array([1.0])))
    print(model.predict(array([0.0])))
    rdd = sc.parallelize([[1.0], [0.0]])
    print(model.predict(rdd).collect())

    # Evaluate on the held-out split.
    predictions = model.predict(testData.map(lambda x: x.features))
    labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
    # BinaryClassificationMetrics expects (score/prediction, label) pairs.
    predictionsAndLabels = predictions.zip(testData.map(lambda lp: lp.label))

    metrics = BinaryClassificationMetrics(predictionsAndLabels)
    # print() as a function -- the original used a Python 2 print statement.
    print("AUC=%f PR=%f" % (metrics.areaUnderROC, metrics.areaUnderPR))

    # Python 3 removed tuple parameters in lambdas (PEP 3113), so the pair is
    # indexed instead of unpacked in the signature.
    testErr = (labelsAndPredictions.filter(lambda vp: vp[0] != vp[1]).count()
               / float(testData.count()))
    print('Test Error = ' + str(testErr))
    print('Learned classification tree model:')
    print(model.toDebugString())

    # Save and reload the model.
    # NOTE(review): DecisionTreeModel.save raises if the target path already
    # exists -- remove ./myDecisionTreeClassificationModel between runs.
    model.save(sc, "./myDecisionTreeClassificationModel")
    sameModel = DecisionTreeModel.load(sc, "./myDecisionTreeClassificationModel")
    sc.stop()
  • 相关阅读:
    Java 密钥库 证书 公钥 私钥
    Theos小例子
    armbian禁用zram
    常见JS混淆器和特征
    命令行工具收藏
    python中生成器的两段代码
    把mysql数据库从windows迁移到linux系统上的方法
    【转载】使用Flink低级处理函数ProcessFunction
    spark读取压缩文件
    SpringBoot系列——validation参数校验
  • 原文地址:https://www.cnblogs.com/bonelee/p/7151341.html
Copyright © 2020-2023  润新知