Refer to the DecisionTree Python docs and DecisionTreeModel Python docs for more details on the API.
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
from pyspark.mllib.util import MLUtils
# Load and parse the data file into an RDD of LabeledPoint.
data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
# Split the data into training and test sets (30% held out for testing)
(trainingData, testData) = data.randomSplit([0.7, 0.3])
# Train a DecisionTree model.
# Empty categoricalFeaturesInfo indicates all features are continuous.
model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={},
                                     impurity='gini', maxDepth=5, maxBins=32)
# Evaluate model on test instances and compute test error
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
testErr = labelsAndPredictions.filter(lambda lp: lp[0] != lp[1]).count() / float(testData.count())
print('Test Error = ' + str(testErr))
print('Learned classification tree model:')
print(model.toDebugString())
# Save and load model
model.save(sc, "target/tmp/myDecisionTreeClassificationModel")
sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeClassificationModel")
Find full example code at "examples/src/main/python/mllib/decision_tree_classification_example.py" in the Spark repo.
class pyspark.mllib.tree.DecisionTree
Learning algorithm for a decision tree model for classification or regression.
New in version 1.1.0.
- classmethod trainClassifier(data, numClasses, categoricalFeaturesInfo, impurity='gini', maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0)
Train a decision tree model for classification.
Parameters:
- data – Training data: RDD of LabeledPoint. Labels should take values {0, 1, ..., numClasses-1}.
- numClasses – Number of classes for classification.
- categoricalFeaturesInfo – Map storing arity of categorical features. An entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}.
- impurity – Criterion used for information gain calculation. Supported values: “gini” or “entropy”. (default: “gini”)
- maxDepth – Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 means 1 internal node + 2 leaf nodes). (default: 5)
- maxBins – Number of bins used for finding splits at each node. (default: 32)
- minInstancesPerNode – Minimum number of instances required at child nodes to create the parent split. (default: 1)
- minInfoGain – Minimum info gain required to create a split. (default: 0.0)
Returns: DecisionTreeModel.
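The impurity and minInfoGain parameters are easier to read with the underlying arithmetic in view. The following is a minimal sketch of the textbook Gini and entropy measures and of the information gain of one candidate split. It is not MLlib's internal implementation: the function names are hypothetical, and base-2 logarithms are assumed for entropy.

```python
import math
from collections import Counter


def gini(labels):
    """Gini impurity: 1 minus the sum of squared class frequencies."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def entropy(labels):
    """Entropy in bits (base-2 logarithm is an assumption of this sketch)."""
    n = len(labels)
    return -sum((c / n) * math.log2(c / n) for c in Counter(labels).values())


def info_gain(parent, left, right, impurity=gini):
    """Parent impurity minus the size-weighted impurity of the two children."""
    n = len(parent)
    weighted = (len(left) / n) * impurity(left) + (len(right) / n) * impurity(right)
    return impurity(parent) - weighted


# Labels of a four-point toy set, split into a left and a right child.
parent = [0.0, 1.0, 1.0, 1.0]
left, right = [0.0], [1.0, 1.0, 1.0]
print(gini(parent))                    # 0.375
print(info_gain(parent, left, right))  # 0.375: both children are pure
```

Under this reading, a candidate split whose gain falls below minInfoGain would not be created, and minInstancesPerNode would reject a split that leaves a child with too few instances.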
Example usage:
>>> from numpy import array
>>> from pyspark.mllib.regression import LabeledPoint
>>> from pyspark.mllib.tree import DecisionTree
>>>
>>> data = [
...     LabeledPoint(0.0, [0.0]),
...     LabeledPoint(1.0, [1.0]),
...     LabeledPoint(1.0, [2.0]),
...     LabeledPoint(1.0, [3.0])
... ]
>>> model = DecisionTree.trainClassifier(sc.parallelize(data), 2, {})
>>> print(model)
DecisionTreeModel classifier of depth 1 with 3 nodes
>>> print(model.toDebugString())
DecisionTreeModel classifier of depth 1 with 3 nodes
  If (feature 0 <= 0.0)
   Predict: 0.0
  Else (feature 0 > 0.0)
   Predict: 1.0
>>> model.predict(array([1.0]))
1.0
>>> model.predict(array([0.0]))
0.0
>>> rdd = sc.parallelize([[1.0], [0.0]])
>>> model.predict(rdd).collect()
[1.0, 0.0]
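The "depth 1 with 3 nodes" line in the output above matches the maxDepth description (depth 1 means 1 internal node + 2 leaf nodes). As a rough sketch, assuming every split is binary, the largest tree a given maxDepth allows can be counted as follows. The helper names max_nodes and max_leaves are hypothetical.

```python
def max_leaves(max_depth):
    """Upper bound on leaf nodes in a binary tree of the given depth."""
    if max_depth < 0:
        raise ValueError("max_depth must be non-negative")
    return 2 ** max_depth


def max_nodes(max_depth):
    """Upper bound on total nodes (internal + leaf) in a binary tree."""
    if max_depth < 0:
        raise ValueError("max_depth must be non-negative")
    return 2 ** (max_depth + 1) - 1


for depth in (0, 1, 5):
    print(depth, max_nodes(depth), max_leaves(depth))
# 0 1 1
# 1 3 2
# 5 63 32
```

These are upper bounds only: a trained tree can be smaller when splits are rejected, for example by minInfoGain or minInstancesPerNode.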
Excerpted from: https://spark.apache.org/docs/latest/api/python/pyspark.mllib.html#pyspark.mllib.tree.DecisionTree