• matlab实现cart(回归分类树)


    作为机器学习的小白和matlab的小白自己参照 python的 《机器学习实战》 写了一下分类回归树,这里记录一下。

    关于决策树的基础概念就不过多介绍了,至于是分类还是回归。。我说不清楚。。我用的数据集是这个http://archive.ics.uci.edu/ml/datasets/Abalone 就是通过一些属性来预测鲍鱼有多少头,下面看一下

    Length / continuous / mm / Longest shell measurement 
    Diameter / continuous / mm / perpendicular to length 
    Height / continuous / mm / with meat in shell 
    Whole weight / continuous / grams / whole abalone 
    Shucked weight / continuous / grams / weight of meat 
    Viscera weight / continuous / grams / gut weight (after bleeding) 
    Shell weight / continuous / grams / after being dried 
    Rings / integer / -- / +1.5 gives the age in years

    这些属性除了最后的Rings是整数,可以看做是离散的,其他都是浮点数,是连续的。所以还是用cart中二分的思想,就是小于等于分一边,大于分一边。但是没有用gini指数,因为熵还是好一点。

    代码在github:https://github.com/jokermask/matlab_cart

    参照《机器学习实战》代码有5个部分:getEnt(获取信息熵),splitDataset(通过属性和阈值分割数据集),chooseBestFeatureToSplit(寻找最佳分割点和阈值),createTree(建树),predict(预测)。

    我按流程梳理一下,首先函数脚本来将数据集划分成,训练集和测试集,然后用训练集建树,用测试集测试,(更改后变成bootstrap sampleing)

    dataset = importdata('abalone.data.txt') ;
    origin_data = dataset.data ;
    labels = {'Length';'Diam';'Height';    'Whole';'Shucked';'Viscera';'Shell';'Rings'} ;
    test_runtimes = 50 ;
    ae = 0 ;
    rr = 0 ;
    for i=1:test_runtimes
        data = sampleWithReplace(origin_data) ;%bootstrap sampling
        len = floor(length(data)/4*3) ;
        train_data = data(1:len,:) ;
        test_data = data(len:end,:) ;
        test_y_truth = test_data(:,end) ;
    %     tree = createTree(train_data,labels,0) ;
    %     predict_y = predict(tree,test_data,labels) ;
    %     com_matrix = [predict_y,test_y_truth] ;
    %     count = sum(predict_y==test_y_truth) ;
    %     disp(com_matrix) ;
    %     disp(mae) ;
    %     disp(rr) ;
    
    %plot single runtime
    %     x = 1:1:size(test_y_truth,1) ;
    %     plot(x,predict_y,'-b',x,test_y_truth,'-r') ;
    
        ae = ae+sum(abs(predict_y-test_y_truth))/size(test_y_truth,1) ;
        rr = rr+count/size(test_y_truth,1) ;
        
        %trian with office tools fitctree
        
        std_tree = fitctree(train_data(:,1:7),train_data(:,end)) ;
        % view(std_tree) ;
        std_y = predict(std_tree,test_data(:,1:7)) ;
        % disp([std_y,y]) ;
        ae = ae+sum(abs(std_y-test_y_truth))/size(test_y_truth,1) ;
        rr = rr+sum(std_y==test_y_truth)/size(test_y_truth,1) ;
    end
    mae = mae / test_runtimes ;
    mrr = rr / test_runtimes ;
    disp('mae') ;
    disp(mae) ;
    disp('mrr') ;
    disp(mrr) ;

    createTree函数:由于matlab没有指针,所以只能写成嵌套结构,就像tree{tree{tree}}这样。我们是递归实现的,但怎么样才会停止建树?条件是当前节点所有标签的类别一样,比如rings都为10,那说明这一个子集已经纯了,或者是这颗树的高度已经超出了我们设的阈值,就停止,第二种情况很可能当前节点下的数据集不纯,我们就找一个出现频率最高的类别代表该节点

    function [ tree ] = createTree( dataset,labels,heightcount )
        len = size(dataset,1) ;
        templabel = dataset(1,end) ;
        tree = templabel ;
        max_depth = 5 ;%最大树高
        flag = 1 ; %判断是否数据集中所有标签都一致了(纯的),是则返回
        for i=1:len
            if templabel~=dataset(i,end) ;
                flag = 0 ;
            end
        end
        if flag==1
            return ;
        end
        if heightcount>max_depth
            labelVec = dataset(:,end) ;
            disp(labelVec) ;
            element = 1:max(labelVec) ;
            counts = histc(labelVec,element) ;
            [~,max_idx] = max(counts) ;
            tree = element(max_idx) ;
            return ;
        end
        [bestFeat,bestT] = chooseBestFeatureToSplit(dataset) ;
        bestFeatLabel = labels{bestFeat} ;
        tree = struct ;%struct储存树结构
        tree.bestFeatLabel = bestFeatLabel ;
        tree.bestT = bestT ;
        tree.greaterthan = createTree(splitDataset(dataset,bestFeat,bestT,1),labels,heightcount+1) ;%大于阈值部分的子树
        tree.lessthan = createTree(splitDataset(dataset,bestFeat,bestT,2),labels,heightcount+1) ;%小于阈值部分的子树
    end

    chooseBestFeatureToSplit函数:在createTree时,每次递归都要找那个当前最佳的特征和阈值,也就是调用chooseBestFeatureToSplit函数,所以两层循环,第一层遍历每个属性,第二层本应该遍历每个属性下的值,但是那样计算量太大了,所以我就将值排序之后分成10端取中位数遍历,在里面找阈值,如果当前节点的数据子集已经不足10个里,那就把所有属性都遍历一哈

    function [ bestFeat,bestT ] = chooseBestFeatureToSplit( dataset )
        [~,numFeats] = size(dataset) ;
        numFeats = numFeats-1 ;%除去标签那一列
        baseEnt = getEnt(dataset) ;
        baseInfoGain = 0 ;
        bestFeat = -1 ;
        for i=1:numFeats
            featVec = dataset(:,i) ;
            %由于值是连续的,所以对于特征向量组排序分成n段取中位数
            sortedFeatVec = sort(featVec,'ascend') ;
            lengthofT = floor(sqrt(length(sortedFeatVec))) ; %取向量长度开根号来确定阈值的个数
            if lengthofT<10
                lengthofT = length(sortedFeatVec) ;
                selectedFeat = sortedFeatVec ;
            else
                step = floor(length(sortedFeatVec)/lengthofT) ;
                selectedFeat = zeros(lengthofT,1) ;
                for j=1:lengthofT
                    head = (j-1)*step+1 ;
                    tail = j*step ;
                    subSortedFeatVec = sortedFeatVec(head:tail) ;
                    selectedFeat(j) = median(subSortedFeatVec) ;
                end
            end
            for k=1:lengthofT
                newEnt = 0 ;
                for l=1:2
                    subDataset = splitDataset(dataset,i,selectedFeat(k),l) ;
                    prob = size(subDataset,1)/size(dataset,1) ;
                    newEnt = newEnt + prob*getEnt(subDataset) ;
                end
                infoGain = baseEnt - newEnt ;
    %             disp('infoGain') ;
    %             disp(infoGain) ;
                if(infoGain>baseInfoGain)
                    baseInfoGain = infoGain ;
                    bestFeat= i ;
                    bestT = selectedFeat(k) ;
                end
            end
        end
    end

    计算信息增益(infoGain)的时候需要用到getEnt(获取信息熵),splitDataset(通过属性和阈值分割数据集)函数

    splitDataset:

    function [ retDataset ] = splitDataset(dataset,axis,value,arg )
    %axis 代表键值的位置 value表示阈值 返回划分后的dataset,arg表示取大于的部分(1)还是小于等于的部分
        if arg==1
            retDataset = dataset(dataset(:,axis)>value,:) ;
        else
            retDataset = dataset(dataset(:,axis)<=value,:) ;
        end
    end
    View Code

    getEnt:

    function [ ent ] = getEnt( data )
    %index present the label
    [datalen,~] = size(data) ;
    maxLabel = max(data(:,end)) ;
    labelCountsMap = zeros(maxLabel,1) ;%rings are all numbers
        for i=1:datalen
            label =  data(i,end) ;
            if labelCountsMap(label)~=0
                labelCountsMap(label) = labelCountsMap(label) + 1 ;
            else
                labelCountsMap(label) = 1 ; 
            end
        end
        ent = 0 ;
    %     disp('labelMap') ;
    %     disp(labelCountsMap) ;
        for i=1:maxLabel
            if labelCountsMap(i)~=0
                prob = labelCountsMap(i)/datalen ;
                ent = ent - prob*log2(prob) ;
            end
        end
    end
    View Code

    最后预测函数:

    function [ classVec ] = predict( tree , dataset , labels)
    %tree应由createTree函数生成
        len = size(dataset,1) ;
        classVec = zeros(len,1) ;
        for i=1:len
            dataVec = dataset(i,1:end-1) ;
            tempnode = tree ;
            while(isstruct(tempnode))
                [~,tempFeatIdx] = ismember(tempnode.bestFeatLabel,labels) ;
                if(dataVec(tempFeatIdx)>tempnode.bestT)
                    tempnode = tempnode.greaterthan ;
                else
                    tempnode = tempnode.lessthan ;
                end
            end
            classVec(i) = tempnode ;
        end
    end
    View Code

    更新了一下代码,加入了boostrap采样,就是有放回的采样,我是这样采用的,有多少个样本就进行多少次有放回采样,然后这个过程进行50次求均值。用了之后,官方的库正确率道理44%,而我的还在30%。。差距一下突显,还需继续学习。。

    补充一下那个sampleWithReplace函数

    function [ sample_data ] = sampleWithReplace( dataset )
        len = size(dataset,1) ;
        randidx = randsample(len,len,true) ;
        sample_data = dataset(randidx,:) ;
    end
  • 相关阅读:
    SEO网站优化10大要点
    三维翻动效果的jquery特效代码
    多款国外虚拟主机简单比较
    jquery同步调用ajax
    3D虚拟技术
    最简单jquery.ajax+php例子(对话框显示文本框输入内容),以小见大(初学手记)
    正则表达式学习博客
    关于XHTML头部声明,什么是DOCTYPE?
    Iframe高度自适应(兼容IE/Firefox、同域/跨域)
    3D立体产业链的发展现状和趋势
  • 原文地址:https://www.cnblogs.com/maskmtj/p/6589584.html
Copyright © 2020-2023  润新知