一.背景
传统的线性回归算法用于拟合所有的数据,当数据量非常的大,特征之间的关联非常的复杂的时候,这个方法就不太现实。这个时候就可以采用对数据进行切片的方式,然后在对切片后的局部的数据进行线性回归,如果首次切片之后的数据还是不符合线性的要求,那么就继续执行切片。在这个过程中树结构和回归算法是非常有用的。而cart算法又是一种常用的树回归算法。
二.概念
树回归算法总的来说,就是找到一个最优的特征,然后根据该特征的最优的特征值进行二元划分子树,如果值大于给定特征值就走左子树,否则就走右子树。通过不断的迭代这个过程就能创建一棵回归树或者分类树。二者区别只在于输入的数据是连续型还是离散型,以及采用的选择最优的特征和特征值的准则。
三.python实现
1.获得输入数据
from numpy import * # load the data def loaddata(filename): fr = open(filename) datamat = [] for line in fr.readlines(): curline = line.strip().split(' ') appenline = map(float, curline) datamat.append(appenline) return datamat
2.将数据集根据选定的最好的特征和特征值进行划分
# split the dataset in feature with value def binarysplitdataset(dataset, feature, value): mat0 = dataset[nonzero(dataset[:, feature] > value)[0], :] mat1 = dataset[nonzero(dataset[:, feature] <= value)[0], :] return mat0, mat1
3.找到最好的划分数据集的特征和特征集(方法是通过简单的遍历每个特征和每个特征值)
# calculate the leafnode value def leafvalue(dataset): return mean(dataset[:, -1]) # calculate the leafnode error def leaferror(dataset): return var(dataset[:, -1])*shape(dataset)[0] # find the best feature and value to split def findbestsplit(dataset, leaftype = leafvalue, errtype = leaferror, op = (1, 4)): neederr = op[0] needsample = op[1] minerror = Inf m, n = shape(dataset) if len(set(dataset[:, -1].T.tolist()[0])) == 1: return None, leaftype(dataset) s = errtype(dataset) bestfeat = 0 bestval = 0 for i in range(n-1): for j in set(dataset[:, i].T.tolist()[0]): mat0, mat1 = binarysplitdataset(dataset, i, j); error = errtype(mat0) + errtype(mat1) if shape(mat0)[0]< needsample or shape(mat1)[0] < needsample: continue if error < minerror: minerror = error bestfeat = i bestval = j if (s - minerror) < neederr: return None, leaftype(dataset) mat0, mat1 = binarysplitdataset(dataset, bestfeat, bestval) if shape(mat0)[0] < needsample or shape(mat1)[0] < needsample: return None, leaftype(dataset) return bestfeat, bestval
4.创建出一个所需的树
# create the regression tree def createtree(dataset, leaftype = leafvalue, errtype = leaferror, op = (1, 4)): feat, val = findbestsplit(dataset, leaftype, errtype, op) print feat if feat is None: return val lft, rig = binarysplitdataset(dataset, feat, val) rettree = {} rettree['feat'] = feat rettree['val'] = val lefttree = createtree(lft, leaftype , errtype, op) righttree = createtree(rig, leaftype, errtype, op) rettree['lefttree'] = lefttree rettree['righttree'] = righttree return rettree
经过以上的四个步骤就可以得到所需的一棵回归树,但是运行的时候可能会出现runtime error,主要是因为mean函数的影响,但是没关系,还是可以得到正确答案。同时要注意的是,机器学习实战上的根据最好的特征和特征值划分数据集的代码是错的,因为它只能获取第一行的数据,要把最后面的[0]去掉。以上的代码已经将这一点改正过来了。
四.树剪枝
剪枝的目的就是通过降低树的复杂度进而避免过拟合,也就是一种正则化的手段。剪枝主要有两种手段一种是预剪枝,一种是后剪枝。