• 【AdaBoost算法】弱分类器训练过程


    一、加载数据(正样本、负样本特征)

    def loadSimpData():
        #样本特征
        datMat = matrix([[ 1. ,  2.1,  0.3],
                         [ 2. ,  1.1,  0.4],
                         [ 1.3,  1. ,  1.2],
                         [ 1. ,  1. ,  1.1],
                         [ 2. ,  1. ,  1.3],
                         [ 7. ,  2. ,  0.35]])
        #正负样本标志
        classLabels = [1.0, 1.0, 1.0, -1.0, -1.0, -1.0]       
        return datMat,classLabels

    如上,总共有6个训练样本(前三个为正样本,后三个为负样本),每个样本总共有3个特征,以上6个样本的特征值如下:

    正样本1:[ 1. ,  2.1,  0.3]

    正样本2:[ 2. ,  1.1,  0.4]

    正样本3:[ 1.3,  1. ,  1.2]

    负样本1:[ 1. ,  1. ,  1.1]

    负样本2:[ 2. ,  1. ,  1.3]

    负样本3:[ 7. ,  2. ,  0.35]

    二、训练一个弱分类器(选出一个特征和其对应的阈值)

    训练弱分类器的过程就是从已有的特征中选出一个特征以及其对应的阈值,使样本分错的错误率最低,即寻找一个最小分错率的过程。

    1. 最小错误率初始化为无穷大;
    2. 遍历样本的所有特征(本例子每个样本有三个特征,即遍历这三个特征值);
    3. 求出该特征值步长(不同特征不一样),(最大特征值-最小特征值)/步长移动次数,如本例,假设步长移动次数为10,则第一个特征步长为(7-1)/10 = 0.6;
    4. 根据特征值步长开始从最小特征值遍历到最大特征值;
    5. 遍历判断符号,大于还是小于;
    6. 计算出阈值(根据最小特征值及步长),根据阈值、符号、及特征索引、开始对样本分类;
    7. 根据每个样本权重以及分类结果计算分错率,若该分错率小于最小分错率,则更新最小分错率;
    8. 返回最小分错率下的特征索引、符号、阈值,即得到弱分类器。

    代码实现如下:

    def buildStump(datMat,classLabels,D):
        dataMatrix = mat(datMat); labelMat = mat(classLabels).T
        m,n = shape(dataMatrix)
        numSteps = 10.0; bestStump = {}; bestClasEst = mat(zeros((m,1)))
        minError = inf #最小错误率初始化为无穷大
        for i in range(n):
            rangeMin = dataMatrix[:,i].min(); rangeMax = dataMatrix[:,i].max();
            
            stepSize = (rangeMax-rangeMin)/numSteps
            for j in range(-1,int(numSteps)+1):
                for inequal in ['lt', 'gt']: 
                    threshVal = (rangeMin + float(j) * stepSize)
                    
                    predictedVals = stumpClassify(dataMatrix,i,threshVal,inequal)
                    errArr = mat(ones((m,1)))
                    errArr[predictedVals == labelMat] = 0
                    weightedError = D.T*errArr  
                    
                    if weightedError < minError:
                        minError = weightedError
                        bestClasEst = predictedVals.copy()
                        bestStump['dim'] = i
                        bestStump['thresh'] = threshVal
                        bestStump['ineq'] = inequal
        return bestStump,minError,bestClasEst

    三、训练结果

    弱分类器结果:

    特征索引:0

    符号:大于

    阈值:1.6000000000000001


    最小分错率:

    0.33333333(可见单独一个弱分类器在以上样本中无法做到完全分对)


    分类结果:

    [ 1.]
    [-1.](分错)
    [ 1.]
    [ 1.]  (分错)
    [-1.]
    [-1.]

  • 相关阅读:
    PHP中的PEAR是什么?
    Cookie禁用了,Session还能用吗?原因详解
    php中echo、print、print_r、var_dump、var_export区别
    超强汇总!110 道 Python 面试笔试题
    九种跨域方式实现原理
    在MySQL中如何使用覆盖索引优化limit分页查询
    Laravel大型项目系列教程(五)之文章和标签管理
    Bootstrap-tagsinput标系统使用心得
    bootstrap-datepicker使用
    谭安林:大数据在教育行业的研究与应用
  • 原文地址:https://www.cnblogs.com/chenpi/p/5128235.html
Copyright © 2020-2023  润新知