• 模型融合策略:开发树模型输出叶子节点作为特征到回归器或者分类器的类


    from sklearn.base import BaseEstimator,ClassifierMixin,RegressorMixin
    from sklearn.preprocessing import OneHotEncoder
    import numpy as np
    
    class TreeLeaf(BaseEstimator,ClassifierMixin,RegressorMixin):
        """
        树模型和其他模型的结合:树模型输出的叶子节点当成特征输入到其他模型中
        """
        def __init__(self,treeModel=[],metaModel=[],n_estimators=[],goal="regression"):
            self.treeModel = treeModel
            self.metaModel = metaModel   
            self.n_estimators = n_estimators
            self.goal = goal
        
        def fit(self,X,y):
            self.best_treemodel = [] #用于保存训练参数后的tree模型   
            self.best_metamodel = [] #用于保存训练参数后的meta模型 
            self.leaf_list  = [] #用于保存叶子节点
            
            for model in self.treeModel:   
                
                model_param = model.fit(X,y) #得到训练参数后的模型
                self.best_treemodel.append(model_param)
                
                leaf = model_param.apply(X)  #输出叶子
                self.leaf_list.append(leaf)
               
            #对叶子节点进行拼接
            leaf_matrix = np.concatenate(self.leaf_list,axis=1)
            
            
            #对叶子节点进行one_hot编码
            self.one_hot_encoder = OneHotEncoder()
            x_one_hot = self.one_hot_encoder.fit_transform(leaf_matrix)
            
            #利用metaModel做拟合                  
            for model in self.metaModel:
                model_param = model.fit(x_one_hot,y)
                self.best_metamodel.append(model_param)
            
            return self
        
        def predict(self,X):
            
            leaf_list_pred = []
            
            for model in self.best_treemodel:            
                leaf_list_pred.append(model.apply(X))
                
            leaf_matrix_pred = np.concatenate(leaf_list_pred,axis=1)    
            
            x_one_hot_pred = self.one_hot_encoder.transform(leaf_matrix_pred)
            
            y_pred_list = []
            for model in self.best_metamodel:
                y_pred_list.append(model.predict(x_one_hot_pred))
            
            if self.goal == "regression":
                return sum(y_pred_list,axis=0)
            elif self.goal == "classification":  
                y_pred = np.zeros(X.shape[0])            
                for i,line in enumerate(np.array(y_pred_list).T):
                    y_pred[i] = np.argmax(np.bincount(line))
                return y_pred
    
    ##################案例测试####################################################
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.ensemble import GradientBoostingClassifier
    from sklearn.datasets import load_iris  
    from sklearn.model_selection import train_test_split
    from sklearn.linear_model import LogisticRegression
    from sklearn.svm import SVC
    from sklearn.metrics import accuracy_score
    from lightgbm import LGBMClassifier
     
    X,y = load_iris(return_X_y=True)  
    X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.3,random_state=0)
    
    treeModel_1 = RandomForestClassifier(n_estimators=20)
    treeModel_2 = LGBMClassifier( n_estimators=30)
    #treeModel_2 = GradientBoostingClassifier(n_estimators=30)
    
    metaModel_1 = LogisticRegression()
    metaModel_2 = SVC()
    
    tl = TreeLeaf(treeModel=[treeModel_1,treeModel_2],metaModel=[metaModel_1,metaModel_2],n_estimators=[20,30],goal="classification")
    tl.fit(X_train,y_train)
    y_pred = tl.predict(X_test)
    
    accuracy_score(y_test,y_pred)

    上述代码主要完成了基于多个树模型的叶子节点输入到多个分类器或者回归器的模型融合策略,具有一定的扩展性和适应度。后面给出了一个基于随机深林和lightGBM的测试实例,供大家参考。这种模型融合策略在不同的地方效果不同,关键还是特征工程是否做得更好,该类方法在训练集上有一定的过拟合倾向。

    欢迎评论和给出意见,如果对你有帮助,请给个关注,激励一下我,谢谢!

  • 相关阅读:
    MVC模式简单介绍
    Android AES加密算法及事实上现
    01背包问题
    C++继承经典样例
    [React] React Fundamentals: Using Refs to Access Components
    [React] React Fundamentals: Owner Ownee Relationship
    [React] React Fundamentals: State Basics
    [React] React Fundamentals: First Component
    [Javascript] Lodash: Refactoring Simple For Loops (_.find, _.findLast, _.filter)
    [Javascript] Create an Array concatAll method
  • 原文地址:https://www.cnblogs.com/wzdLY/p/9677784.html
Copyright © 2020-2023  润新知