• scikit-learn的GBDT工具进行特征选取。


    http://blog.csdn.net/w5310335/article/details/48972587

    使用GBDT选取特征

    2015-03-31

    本文介绍如何使用scikit-learn的GBDT工具进行特征选取。

    为什麽选取特征


    有些特征意义不大,删除后不影响效果,甚至可能提升效果。

    关于GBDT(Gradient Boosting Decision Tree)


    可以参考:

    GBDT(MART)概念简介

    GBDT(MART) 迭代决策树入门教程 | 简介

    机器学习中的算法(1)-决策树模型组合之随机森林与GBDT

    如何在numpy数组中选取若干列或者行?


    >>> import numpy as np
    >>> tmp_a = np.array([[1,1], [0.4, 4], [1., 0.9]])
    >>> tmp_a
    array([[ 1. ,  1. ],  
           [ 0.4,  4. ],
           [ 1. ,  0.9]])
    >>> tmp_a[[0,1],:]  # 选第0、1行
    array([[ 1. ,  1. ],  
           [ 0.4,  4. ]])
    >>> tmp_a[np.array([True, False, True]), :]  # 选第0、2行
    array([[ 1. ,  1. ],  
           [ 1. ,  0.9]])
    >>> tmp_a[:,[0]]    # 选第0列
    array([[ 1. ],  
           [ 0.4],
           [ 1. ]])
    >>> tmp_a[:, np.array([True, False])]  # 选第0列
    array([[ 1. ],  
           [ 0.4],
           [ 1. ]])
    

    生成数据集


    参考基于贝叶斯的文本分类实战。部分方法在原始数据集的预测效果也在基于贝叶斯的文本分类实战这篇文章里。

    训练GBDT


    >>> from sklearn.ensemble import GradientBoostingClassifier
    >>> gbdt = GradientBoostingClassifier()
    >>> gbdt.fit(training_data, training_labels)  # 训练。喝杯咖啡吧
    GradientBoostingClassifier(init=None, learning_rate=0.1, loss='deviance',  
                  max_depth=3, max_features=None, max_leaf_nodes=None,
                  min_samples_leaf=1, min_samples_split=2,
                  min_weight_fraction_leaf=0.0, n_estimators=100,
                  random_state=None, subsample=1.0, verbose=0,
                  warm_start=False)
    >>> gbdt.feature_importances_   # 据此选取重要的特征
    array([  2.08644807e-06,   0.00000000e+00,   8.93452010e-04, ...,  
             5.12199658e-04,   0.00000000e+00,   0.00000000e+00])
    >>> gbdt.feature_importances_.shape
    (19630,)
    

    看一下GBDT的分类效果:

    >>> gbdt_predict_labels = gbdt.predict(test_data)
    >>> sum(gbdt_predict_labels==test_labels)  # 比 多项式贝叶斯 差许多
    414  
    

    新的训练集和测试集(只保留了1636个特征,原先是19630个特征):

    >>> new_train_data = training_data[:, feature_importances>0]
    >>> new_train_data.shape  # 只保留了1636个特征
    (1998, 1636)
    >>> new_test_data = test_data[:, feature_importances>0]
    >>> new_test_data.shape
    (509, 1636)
    

    使用多项式贝叶斯处理新数据


    >>> from sklearn.naive_bayes import MultinomialNB
    >>> bayes = MultinomialNB() 
    >>> bayes.fit(new_train_data, training_labels)
    MultinomialNB(alpha=1.0, class_prior=None, fit_prior=True)  
    >>> bayes_predict_labels = bayes.predict(new_test_data)
    >>> sum(bayes_predict_labels == test_labels)   # 之前预测正确的样本数量是454
    445  
    

    使用伯努利贝叶斯处理新数据


    >>> from sklearn.naive_bayes import BernoulliNB
    >>> bayes2 = BernoulliNB()
    >>> bayes2.fit(new_train_data, training_labels)
    BernoulliNB(alpha=1.0, binarize=0.0, class_prior=None, fit_prior=True)  
    >>> bayes_predict_labels = bayes2.predict(new_test_data)
    >>> sum(bayes_predict_labels == test_labels)   # 之前预测正确的样本数量是387
    422  
    

    使用Logistic回归处理新数据


    对原始特征组成的数据集:

    >>> from sklearn.linear_model import LogisticRegression
    >>> lr1 = LogisticRegression()
    >>> lr1.fit(training_data, training_labels)
    LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,  
              intercept_scaling=1, max_iter=100, multi_class='ovr',
              penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
              verbose=0)
    >>> lr1_predict_labels = lr1.predict(test_data)
    >>> sum(lr1_predict_labels == test_labels)
    446  
    

    对削减后的特征组成的数据集:

    >>> lr2 = LogisticRegression()
    >>> lr2.fit(new_train_data, training_labels)
    LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,  
              intercept_scaling=1, max_iter=100, multi_class='ovr',
              penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
              verbose=0)
    >>> lr2_predict_labels = lr2.predict(new_test_data)
    >>> sum(lr2_predict_labels == test_labels)  # 正确率略微提升
    449  
    

    (完)

  • 相关阅读:
    Transaction 事务简单详解
    JAVA------6.短信配置并返回
    JAVA------5.启动服务端,客户端发送数据,用户端接收数据,string数组转byte字节,CrcUtil校验
    java------4.根据经纬度排序,并计算距离。。。。。。。。根据地址计算出经纬度
    svn------找不到路径
    java------3.时间戳
    服务器------3.根据经纬度划分区域
    php-------1.ie11配置httpWatch9.1.21
    mysql------1.查询当天的所有数据
    html------1.网页mp3语音展示,点击图片放大,点击图片跳转链接,调表格
  • 原文地址:https://www.cnblogs.com/DjangoBlog/p/6211255.html
Copyright © 2020-2023  润新知