• 第十二篇:深入学习高级非线性回归算法 --- 树回归系列算法


    前言

           前文讨论的回归算法都是全局且针对线性问题的回归,即使是其中的局部加权线性回归法,也有其弊端(具体请参考前文)

           采用全局模型会导致模型非常的臃肿,因为需要计算所有的样本点,而且现实生活中很多样本都有大量的特征信息。

           另一方面,实际生活中更多的问题都是非线性问题。

           针对这些问题,有了树回归系列算法。

    回归树

           在先前决策树的学习中,构建树是采用的 ID3 算法。在回归领域,该算法就有个问题,就是派生子树是按照所有可能值来进行派生。

           因此 ID3 算法无法处理连续性数据。

           故可使用二元切分法,以某个特定值为界进行切分。在这种切分法下,子树个数小于等于2。

           除此之外,再修改择优原则香农熵 (因为数据变为连续型的了),便可将树构建成一棵可用于回归的树,这样一棵树便叫做回归树。

           构建回归树的伪代码:

    1 找到最佳的待切分特征:
    2     如果该节点不能再分,将此节点存为叶节点。
    3     执行二元切分
    4     左右子树分别递归调用此函数

           二元切分的伪代码:

    1 对每个特征:
    2     对每个特征值:
    3         将数据集切成两份
    4         计算切分误差
    5         如果当前误差小于最小误差,则更新最佳切分以及最小误差。

           特别说明, (并直接建立叶节点)有三种情况:
                  1. 特征值划分完毕
                  2. 划分子集太小
                  3. 划分后误差改进不大
           这几个操作被称做 "预剪枝"。
      下面给出一个完整的回归树的小程序:

      1 #!/usr/bin/env python
      2 # -*- coding:UTF-8 -*-
      3 
      4 '''
      5 Created on 20**-**-**
      6 
      7 @author: fangmeng
      8 '''
      9 
     10 from numpy import *
     11 
     12 def loadDataSet(fileName):
     13     '载入测试数据'
     14     
     15     dataMat = []
     16     fr = open(fileName)
     17     for line in fr.readlines():
     18         curLine = line.strip().split('	')
     19         # 所有元素转换为浮点类型(函数编程)
     20         fltLine = map(float,curLine)
     21         dataMat.append(fltLine)
     22     return dataMat
     23 
     24 #============================
     25 # 输入:
     26 #        dataSet: 待切分数据集
     27 #        feature: 切分特征序号
     28 #        value:    切分值
     29 # 输出:
     30 #        mat0,mat1: 切分结果
     31 #============================
     32 def binSplitDataSet(dataSet, feature, value):
     33     '切分数据集'
     34     
     35     mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:][0]
     36     mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:][0]
     37     return mat0,mat1
     38 
     39 #========================================
     40 # 输入:
     41 #        dataSet: 数据集
     42 # 输出:
     43 #        mean(dataSet[:,-1]): 均值(也就是叶节点的内容)
     44 #========================================
     45 def regLeaf(dataSet):
     46     '生成叶节点'
     47     
     48     return mean(dataSet[:,-1])
     49 
     50 #========================================
     51 # 输入:
     52 #        dataSet: 数据集
     53 # 输出:
     54 #        var(dataSet[:,-1]) * shape(dataSet)[0]: 平方误差
     55 #========================================
     56 def regErr(dataSet):
     57     '计算平方误差'
     58     
     59     return var(dataSet[:,-1]) * shape(dataSet)[0]
     60 
     61 #========================================
     62 # 输入:
     63 #        dataSet: 数据集
     64 #        leafType: 叶子节点生成器
     65 #        errType: 误差统计器
     66 #        ops: 相关参数
     67 # 输出:
     68 #        bestIndex: 最佳划分特征 
     69 #        bestValue: 最佳划分特征值
     70 #========================================
     71 def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
     72     '选择最优划分'
     73     
     74     # 获得相关参数中的最大样本数和最小误差效果提升值
     75     tolS = ops[0]; 
     76     tolN = ops[1]
     77     
     78     # 如果所有样本点的值一致,那么直接建立叶子节点。
     79     if len(set(dataSet[:,-1].T.tolist()[0])) == 1: 
     80         return None, leafType(dataSet)
     81     
     82     m,n = shape(dataSet)
     83     # 当前误差
     84     S = errType(dataSet)
     85     # 最小误差
     86     bestS = inf; 
     87     # 最小误差对应的划分方式
     88     bestIndex = 0; 
     89     bestValue = 0
     90     
     91     # 对于所有特征
     92     for featIndex in range(n-1):
     93         # 对于某个特征的所有特征值
     94         for splitVal in set(dataSet[:,featIndex]):
     95             # 划分
     96             mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
     97             # 如果划分后某个子集中的个数不达标
     98             if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
     99             # 当前划分方式的误差
    100             newS = errType(mat0) + errType(mat1)
    101             # 如果这种划分方式的误差小于最小误差
    102             if newS < bestS: 
    103                 bestIndex = featIndex
    104                 bestValue = splitVal
    105                 bestS = newS
    106     
    107     # 如果当前划分方式还不如不划分时候的误差效果
    108     if (S - bestS) < tolS: 
    109         return None, leafType(dataSet)
    110     # 按照最优划分方式进行划分
    111     mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    112     # 如果划分后某个子集中的个数不达标
    113     if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
    114         return None, leafType(dataSet)
    115     
    116     return bestIndex,bestValue
    117 
    118 #========================================
    119 # 输入:
    120 #        dataSet: 数据集
    121 #        leafType: 叶子节点生成器
    122 #        errType: 误差统计器
    123 #        ops: 相关参数
    124 # 输出:
    125 #        retTree: 回归树
    126 #========================================
    127 def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    128     '构建回归树'
    129     
    130     # 选择最佳划分方式
    131     feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    132     # feat为None的时候无需划分返回叶子节点
    133     if feat == None: return val #if the splitting hit a stop condition return val
    134     
    135     # 递归调用构建函数并更新树
    136     retTree = {}
    137     retTree['spInd'] = feat
    138     retTree['spVal'] = val
    139     lSet, rSet = binSplitDataSet(dataSet, feat, val)
    140     retTree['left'] = createTree(lSet, leafType, errType, ops)
    141     retTree['right'] = createTree(rSet, leafType, errType, ops)
    142     
    143     return retTree  
    144 
    145 def test():
    146     '展示结果'
    147     
    148     # 载入数据
    149     myDat = loadDataSet('/home/fangmeng/ex0.txt')
    150     # 构建回归树
    151     myDat = mat(myDat)
    152     
    153     print createTree(myDat)
    154     
    155     
    156 if __name__ == '__main__':
    157     test()

           测试结果:

    回归树的优化工作 - 剪枝

           在上面的代码中,递归的条件中已经加入了重重的 "剪枝" 工作。

           这些在建树的时候的剪枝操作通常被成为预剪枝。这是很有很有必要的,经过预剪枝的树几乎就是没有预剪枝树的大小的百分之一甚至更小,而性能相差无几

           而在树建立完毕之后,基于训练集和测试集能做更多更高效的剪枝工作,这些工作叫做 "后剪枝"。

           可见,剪枝是一项较大的工作量,是对树非常关键的优化过程。

           后剪枝过程的伪代码如下:

    1 基于已有的树切分测试数据:
    2     如果存在任一子集是一棵树,则在该子集上递归该过程。
    3     计算将当前两个叶节点合并后的误差
    4     计算不合并的误差
    5     如果合并会降低误差,则将叶节点合并。

           具体实现函数如下:

     1 #===================================
     2 # 输入:
     3 #        obj: 判断对象
     4 # 输出:
     5 #        (type(obj).__name__=='dict'): 判断结果
     6 #===================================
     7 def isTree(obj):
     8     '判断对象是否为树类型'
     9     
    10     return (type(obj).__name__=='dict')
    11 
    12 #===================================
    13 # 输入:
    14 #        tree: 处理对象
    15 # 输出:
    16 #        (tree['left']+tree['right'])/2.0: 坍塌后的替代值
    17 #===================================
    18 def getMean(tree):
    19     '坍塌处理'
    20     
    21     if isTree(tree['right']): tree['right'] = getMean(tree['right'])
    22     if isTree(tree['left']): tree['left'] = getMean(tree['left'])
    23     
    24     return (tree['left']+tree['right'])/2.0
    25   
    26 #===================================
    27 # 输入:
    28 #        tree: 处理对象
    29 #        testData: 测试数据集
    30 # 输出:
    31 #        tree: 剪枝后的树
    32 #===================================  
    33 def prune(tree, testData):
    34     '后剪枝'
    35     
    36     # 无测试数据则坍塌此树
    37     if shape(testData)[0] == 0: 
    38         return getMean(tree)
    39     
    40     # 若左/右子集为树类型
    41     if (isTree(tree['right']) or isTree(tree['left'])):
    42         # 划分测试集
    43         lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    44     # 在新树新测试集上递归进行剪枝
    45     if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
    46     if isTree(tree['right']): tree['right'] =  prune(tree['right'], rSet)
    47     
    48     # 如果两个子集都是叶子的话,则在进行误差评估后决定是否进行合并。
    49     if not isTree(tree['left']) and not isTree(tree['right']):
    50         lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    51         errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) +sum(power(rSet[:,-1] - tree['right'],2))
    52         treeMean = (tree['left']+tree['right'])/2.0
    53         errorMerge = sum(power(testData[:,-1] - treeMean,2))
    54         if errorMerge < errorNoMerge: 
    55             return treeMean
    56         else: return tree
    57     else: return tree

    模型树

           这也是一种很棒的树回归算法。

           该算法将所有的叶子节点不是表述成一个值,而是对叶子部分节点建立线性模型。比如可以是最小二乘法的基本线性回归模型。

           这样在叶子节点里存放的就是一组线性回归系数了。非叶子节点部分构造就和回归树一样。

           这个是上面建立回归树算法的函数头:

           createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):

           对于模型树,只需要修改修改 leafType(叶节点构造器) 和 errType(误差分析器) 的实现即可,分别对应如下modelLeaf 函数和 modelErr 函数:

     1 #=========================
     2 # 输入:
     3 #        dataSet: 测试集
     4 # 输出:
     5 #        ws,X,Y: 回归模型
     6 #=========================
     7 def linearSolve(dataSet):
     8     '辅助函数,用于构建线性回归模型。'
     9     
    10     m,n = shape(dataSet)
    11     X = mat(ones((m,n))); 
    12     Y = mat(ones((m,1)))
    13     X[:,1:n] = dataSet[:,0:n-1]; 
    14     Y = dataSet[:,-1]
    15     xTx = X.T*X
    16     if linalg.det(xTx) == 0.0:
    17         raise NameError('系数矩阵不可逆')
    18     ws = xTx.I * (X.T * Y)
    19     return ws,X,Y
    20 
    21 #=======================
    22 # 输入:
    23 #       dataSet: 数据集
    24 # 输出:
    25 #        ws: 回归系数
    26 #=======================
    27 def modelLeaf(dataSet):
    28     '叶节点构造器'
    29     
    30     ws,X,Y = linearSolve(dataSet)
    31     return ws
    32 
    33 #=======================================
    34 # 输入:
    35 #       dataSet: 数据集
    36 # 输出:
    37 #        sum(power(Y - yHat,2)): 平方误差
    38 #=======================================
    39 def modelErr(dataSet):
    40     '误差分析器'
    41     
    42     ws,X,Y = linearSolve(dataSet)
    43     yHat = X * ws
    44     return sum(power(Y - yHat,2))

    回归树 / 模型树的使用

           前面的工作主要介绍了两种树 - 回归树,模型树的构建,下面进一步学习如何利用这些树来进行预测。

           当然,本质也就是递归遍历树

           下为遍历代码,通过修改参数设置要使用并传递进来的是回归树还是模型树:

     1 #==============================
     2 # 输入:
     3 #       model: 叶子
     4 #       inDat: 测试数据
     5 # 输出:
     6 #        float(model): 叶子值
     7 #==============================
     8 def regTreeEval(model, inDat):
     9     '回归树预测'
    10     
    11     return float(model)
    12 
    13 #==============================
    14 # 输入:
    15 #       model: 叶子
    16 #       inDat: 测试数据
    17 # 输出:
    18 #        float(X*model): 叶子值
    19 #==============================
    20 def modelTreeEval(model, inDat):
    21     '模型树预测'
    22     n = shape(inDat)[1]
    23     X = mat(ones((1,n+1)))
    24     X[:,1:n+1]=inDat
    25     return float(X*model)
    26 
    27 #==============================
    28 # 输入:
    29 #        tree: 待遍历树
    30 #        inDat: 测试数据
    31 #        modelEval: 叶子值获取器
    32 # 输出:
    33 #        分类结果
    34 #==============================
    35 def treeForeCast(tree, inData, modelEval=regTreeEval):
    36     '使用回归/模型树进行预测 (modelEval参数指定)'
    37     
    38     # 如果非树类型,返回值。
    39     if not isTree(tree): return modelEval(tree, inData)
    40     
    41     # 左遍历
    42     if inData[tree['spInd']] > tree['spVal']:
    43         if isTree(tree['left']): return treeForeCast(tree['left'], inData, modelEval)
    44         else: return modelEval(tree['left'], inData)
    45         
    46     # 右遍历
    47     else:
    48         if isTree(tree['right']): return treeForeCast(tree['right'], inData, modelEval)
    49         else: return modelEval(tree['right'], inData)

           使用方法非常简单,将树和要分类的样本传递进去就可以了。如果是模型树就将分类函数 treeForeCast 的第三个参数改为modelTreeEval即可。

           这里就不再演示实验具体过程了。

    小结

           1. 选择哪个回归方法,得看哪个方法的相关系数高。(可使用 corrcoef 函数计算)

           2. 树的回归和分类算法其实本质上都属于贪心算法,不断去寻找局部最优解。

           3. 关于回归的讨论就先告一段落,接下来将进入到无监督学习部分。

  • 相关阅读:
    10 个雷人的注释,就怕你不敢用!
    Java 14 之模式匹配,非常赞的一个新特性!
    poj 3661 Running(区间dp)
    LightOJ
    hdu 5540 Secrete Master Plan(水)
    hdu 5584 LCM Walk(数学推导公式,规律)
    hdu 5583 Kingdom of Black and White(模拟,技巧)
    hdu 5578 Friendship of Frog(multiset的应用)
    hdu 5586 Sum(dp+技巧)
    hdu 5585 Numbers
  • 原文地址:https://www.cnblogs.com/muchen/p/6298634.html
Copyright © 2020-2023  润新知