• python sklearn包——grid search笔记


    Preface:算法不够好,需要调试参数时必不可少。比如SVM的惩罚因子C,核函数kernel,gamma参数等,对于不同的数据使用不同的参数,结果效果可能差1-5个点,sklearn为我们提供专门调试参数的函数grid_search。

    在sklearn中以API的形式给出介绍。在离线包中函数较多,但常用为GridSearchCV()这个函数。

    1.GridSearchCV:

    看例子最为容易懂得使用其的方法。

    sklearn包中介绍的例子:

    卤煮直接从官网上贴上例子:grid_search_digits.py

    [python] view plain copy
     
     
     
     在CODE上查看代码片派生到我的代码片
    1. from __future__ import print_function  
    2.   
    3. from sklearn import datasets  
    4. from sklearn.cross_validation import train_test_split  
    5. from sklearn.grid_search import GridSearchCV  
    6. from sklearn.metrics import classification_report  
    7. from sklearn.svm import SVC  
    8.   
    9. print(__doc__)  
    10.   
    11. # Loading the Digits dataset  
    12. digits = datasets.load_digits()  
    13.   
    14. # To apply an classifier on this data, we need to flatten the image, to  
    15. # turn the data in a (samples, feature) matrix:  
    16. n_samples = len(digits.images)  
    17. X = digits.images.reshape((n_samples, -1))  
    18. y = digits.target  
    19.   
    20. # Split the dataset in two equal parts  
    21. X_train, X_test, y_train, y_test = train_test_split(  
    22.     X, y, test_size=0.5, random_state=0)  
    23.   
    24. # Set the parameters by cross-validation  
    25. tuned_parameters = [{'kernel': ['rbf'], 'gamma': [1e-3, 1e-4],  
    26.                      'C': [1, 10, 100, 1000]},  
    27.                     {'kernel': ['linear'], 'C': [1, 10, 100, 1000]}]  
    28.   
    29. scores = ['precision', 'recall']  
    30.   
    31. for score in scores:  
    32.     print("# Tuning hyper-parameters for %s" % score)  
    33.     print()  
    34.   
    35.     clf = GridSearchCV(SVC(C=1), tuned_parameters, cv=5,  
    36.                        scoring='%s_weighted' % score)  
    37.     clf.fit(X_train, y_train)  
    38.   
    39.     print("Best parameters set found on development set:")  
    40.     print()  
    41.     print(clf.best_params_)  
    42.     print()  
    43.     print("Grid scores on development set:")  
    44.     print()  
    45.     for params, mean_score, scores in clf.grid_scores_:  
    46.         print("%0.3f (+/-%0.03f) for %r"  
    47.               % (mean_score, scores.std() * 2, params))  
    48.     print()  
    49.   
    50.     print("Detailed classification report:")  
    51.     print()  
    52.     print("The model is trained on the full development set.")  
    53.     print("The scores are computed on the full evaluation set.")  
    54.     print()  
    55.     y_true, y_pred = y_test, clf.predict(X_test)  
    56.     print(classification_report(y_true, y_pred))  
    57.     print()  

    其中,将参数放在列表中

    tuned_parameters = [{'kernel': ['rbf'], 'gamma': [1e-3, 1e-4],
                         'C': [1, 10, 100, 1000]},
                        {'kernel': ['linear'], 'C': [1, 10, 100, 1000]}]
    建立分类器clf时,调用GridSearchCV()函数,将上述参数列表的变量传入函数。并且可传入交叉验证cv参数,设置为5折交叉验证。对训练集训练完成后调用best_params_变量,打印出训练的最佳参数组。

    Figure :运行结果

    可以看出,其得出最佳参数组字典,还有每一次用参数组进行训练得出的得分。最后在测试集上,给出10个类别的测试报告,对于类别0,RPF都为1,。。。。这里使用sklearn.metrics下的classification_report()函数即可,输入测试集真实的结果和预测的结果即返回每个类别的准确率召回率F值以及宏平均值。

    对于SVM分类器,这里只列出线性核和RBF核,其中线性核不必用gamma这个参数,RBF核可用不同惩罚值C和不同的gamma值作为组合。上述列出的结果即可看出有哪些组合。这里的结果是RBF核,惩罚项为10,gamma值为0.001效果最佳。卤煮以为RBF核是比较好的,但是在最近的学习中,确实是不一定,用了线性核效果更好些,但选训练非常慢,数据集不一样效果差很多吧,可能。

    另外有个grid_search_text_feature_extraction.py程序写得也很不错,只是卤煮fetch_20newsgroup数据集没有准备好,跑不了

  • 相关阅读:
    使用GitLab搭建Git仓库
    SpringBoot web开发
    springboot配置
    springboot自动配置原理
    springboot修改端口号
    springboot创建方式
    junit运行多个测试的方法
    junit常用注解
    junit断言
    sublime将.m文件关联MATLAB类型高亮
  • 原文地址:https://www.cnblogs.com/King-of-Knight/p/9482769.html
Copyright © 2020-2023  润新知