• The Road to Machine Learning: Grid Search and Parallel Search in Python with GridSearchCV, a Model Validation Method


    Git repository: https://github.com/linyi0604/MachineLearning

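    How do we decide which hyperparameter values a model should use?

    k-fold cross-validation:
    Split the samples into k parts (folds).
    In each round, hold out one fold as test data and train on the other k-1 folds.
    This gives k rounds of training and testing in total.
    Every sample is used for both training and testing, so the data is fully exploited when estimating how well the model performs.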
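The splitting scheme described above can be sketched with scikit-learn's `KFold`. The ten toy samples and the choice of k = 5 are assumptions made only for illustration; they do not come from the article's experiment.

```python
import numpy as np
from sklearn.model_selection import KFold

# Ten toy samples, identified only by their index (illustrative data)
samples = np.arange(10)

# k = 5 folds; without shuffling, each fold is a contiguous range of indices
kfold = KFold(n_splits=5)

folds = []
for train_idx, test_idx in kfold.split(samples):
    # Each round holds out one fold for testing and trains on the rest
    folds.append((train_idx.tolist(), test_idx.tolist()))
    print("train:", train_idx.tolist(), "test:", test_idx.tolist())
```

Across the five rounds every sample lands in the test fold exactly once, which is what lets the procedure use all of the data for evaluation.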


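    Grid search:
    A brute-force, exhaustive search method.
    List several candidate values for each model hyperparameter,
    evaluate the model on every combination of the listed values,
    and pick the combination that scores best.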
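To make the enumeration concrete, the sketch below expands a grid into its individual combinations with scikit-learn's `ParameterGrid`. The grid values mirror the `svc__gamma` and `svc__C` ranges used in the article's script; the variable names `param_grid` and `combos` are illustrative.

```python
import numpy as np
from sklearn.model_selection import ParameterGrid

# Candidate values per hyperparameter (same ranges as the article's script)
param_grid = {
    "svc__gamma": np.logspace(-2, 1, 4),  # 0.01, 0.1, 1, 10
    "svc__C": np.logspace(-1, 1, 3),      # 0.1, 1, 10
}

# ParameterGrid yields one dict per combination (the Cartesian product)
combos = list(ParameterGrid(param_grid))
print(len(combos))  # 4 gamma values x 3 C values
for combo in combos[:3]:
    print(combo)
```

Grid search then fits and scores the model once per combination and per cross-validation fold, so the cost grows multiplicatively with every hyperparameter added to the grid.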

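    Parallel search:
    Each parameter combination is evaluated independently of the others,
    so the grid search can run several of them at the same time as parallel jobs (scikit-learn dispatches them through joblib, by default as worker processes rather than threads).
    This is called parallel search.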
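A minimal sketch of the idea, assuming the small bundled iris dataset and a bare `SVC` instead of the article's text pipeline so that it runs quickly and offline. The helper `run_search` and the grid values are hypothetical; only the `n_jobs` argument of `GridSearchCV` is the point here.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

# Small bundled dataset (an assumption; the article uses 20 newsgroups)
X, y = load_iris(return_X_y=True)
param_grid = {"gamma": [0.01, 0.1, 1.0], "C": [0.1, 1.0, 10.0]}


def run_search(n_jobs):
    """Run the same 3-fold grid search with a given number of parallel jobs."""
    search = GridSearchCV(SVC(), param_grid, cv=3, n_jobs=n_jobs)
    search.fit(X, y)
    return search


serial = run_search(n_jobs=1)    # one fit at a time
parallel = run_search(n_jobs=2)  # two fits at a time

# The fits are independent, so both runs should select the same parameters
print(serial.best_params_, parallel.best_params_)
```

Only the wall-clock time should differ between the two runs; on a dataset this small the overhead of starting workers may well outweigh the gain.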



    Python implementation:
    from sklearn.datasets import fetch_20newsgroups
    from sklearn.model_selection import GridSearchCV, train_test_split
    import numpy as np
    from sklearn.svm import SVC
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.pipeline import Pipeline

    # Blog post: http://www.cnblogs.com/Lin-Yi/p/9000989.html

    # Download the full 20 newsgroups dataset over the network
    news = fetch_20newsgroups(subset="all")
    # Split the first 3000 documents into training and test data
    x_train, x_test, y_train, y_test = train_test_split(news.data[:3000],
                                                        news.target[:3000],
                                                        test_size=0.25,
                                                        random_state=33)

    # Use a Pipeline to chain the vectorizer and the classifier
    clf = Pipeline([("vect", TfidfVectorizer(stop_words="english", analyzer="word")), ("svc", SVC())])

    # Two hyperparameters are searched: 4 values of svc__gamma and 3 values of svc__C, 12 combinations in total
    # np.logspace(start, end, num) returns num values in geometric progression from 10^start to 10^end
    parameters = {"svc__gamma": np.logspace(-2, 1, 4), "svc__C": np.logspace(-1, 1, 3)}

    # Grid search
    # Create a grid search: 12 parameter combinations, 3-fold cross-validation
    gs = GridSearchCV(clf, parameters, verbose=2, refit=True, cv=3)
    # n_jobs=-1 uses all CPU cores; n_jobs=5 would run 5 jobs concurrently
    # Windows does not support fork-based workers, so the parallel call below is intended for Linux, Unix, and macOS
    # gs = GridSearchCV(clf, parameters, verbose=2, refit=True, cv=3, n_jobs=-1)


    # Run the grid search with a single job
    # Note: fit() returns the fitted GridSearchCV object itself, not the elapsed time
    time_ = gs.fit(x_train, y_train)
    print(time_)
    print(gs.best_params_, gs.best_score_)
    # Accuracy of the best model on the test set
    print(gs.score(x_test, y_test))
    '''
    Fitting 3 folds for each of 12 candidates, totalling 36 fits
    [CV] svc__C=0.1, svc__gamma=0.01 .....................................
    [CV] ............................ svc__C=0.1, svc__gamma=0.01 -   8.3s
    [Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    8.3s remaining:    0.0s
    [CV] svc__C=0.1, svc__gamma=0.01 .....................................
    [CV] ............................ svc__C=0.1, svc__gamma=0.01 -   8.5s
    [CV] svc__C=0.1, svc__gamma=0.01 .....................................
    [CV] ............................ svc__C=0.1, svc__gamma=0.01 -   8.5s
    [CV] svc__C=0.1, svc__gamma=0.1 ......................................
    [CV] ............................. svc__C=0.1, svc__gamma=0.1 -   8.4s
    [CV] svc__C=0.1, svc__gamma=0.1 ......................................
    [CV] ............................. svc__C=0.1, svc__gamma=0.1 -   8.5s
    [CV] svc__C=0.1, svc__gamma=0.1 ......................................
    [CV] ............................. svc__C=0.1, svc__gamma=0.1 -   8.5s
    [CV] svc__C=0.1, svc__gamma=1.0 ......................................
    [CV] ............................. svc__C=0.1, svc__gamma=1.0 -   8.4s
    [CV] svc__C=0.1, svc__gamma=1.0 ......................................
    [CV] ............................. svc__C=0.1, svc__gamma=1.0 -   8.6s
    [CV] svc__C=0.1, svc__gamma=1.0 ......................................
    [CV] ............................. svc__C=0.1, svc__gamma=1.0 -   8.6s
    [CV] svc__C=0.1, svc__gamma=10.0 .....................................
    [CV] ............................ svc__C=0.1, svc__gamma=10.0 -   8.5s
    [CV] svc__C=0.1, svc__gamma=10.0 .....................................
    [CV] ............................ svc__C=0.1, svc__gamma=10.0 -   8.6s
    [CV] svc__C=0.1, svc__gamma=10.0 .....................................
    [CV] ............................ svc__C=0.1, svc__gamma=10.0 -   8.7s
    [CV] svc__C=1.0, svc__gamma=0.01 .....................................
    [CV] ............................ svc__C=1.0, svc__gamma=0.01 -   8.3s
    [CV] svc__C=1.0, svc__gamma=0.01 .....................................
    [CV] ............................ svc__C=1.0, svc__gamma=0.01 -   8.4s
    [CV] svc__C=1.0, svc__gamma=0.01 .....................................
    [CV] ............................ svc__C=1.0, svc__gamma=0.01 -   8.5s
    [CV] svc__C=1.0, svc__gamma=0.1 ......................................
    [CV] ............................. svc__C=1.0, svc__gamma=0.1 -   8.3s
    [CV] svc__C=1.0, svc__gamma=0.1 ......................................
    [CV] ............................. svc__C=1.0, svc__gamma=0.1 -   8.4s
    [CV] svc__C=1.0, svc__gamma=0.1 ......................................
    [CV] ............................. svc__C=1.0, svc__gamma=0.1 -   8.5s
    [CV] svc__C=1.0, svc__gamma=1.0 ......................................
    [CV] ............................. svc__C=1.0, svc__gamma=1.0 -   8.5s
    [CV] svc__C=1.0, svc__gamma=1.0 ......................................
    [CV] ............................. svc__C=1.0, svc__gamma=1.0 -   8.6s
    [CV] svc__C=1.0, svc__gamma=1.0 ......................................
    [CV] ............................. svc__C=1.0, svc__gamma=1.0 -   8.7s
    [CV] svc__C=1.0, svc__gamma=10.0 .....................................
    [CV] ............................ svc__C=1.0, svc__gamma=10.0 -   8.5s
    [CV] svc__C=1.0, svc__gamma=10.0 .....................................
    [CV] ............................ svc__C=1.0, svc__gamma=10.0 -   8.6s
    [CV] svc__C=1.0, svc__gamma=10.0 .....................................
    [CV] ............................ svc__C=1.0, svc__gamma=10.0 -   8.7s
    [CV] svc__C=10.0, svc__gamma=0.01 ....................................
    [CV] ........................... svc__C=10.0, svc__gamma=0.01 -   8.4s
    [CV] svc__C=10.0, svc__gamma=0.01 ....................................
    [CV] ........................... svc__C=10.0, svc__gamma=0.01 -   8.4s
    [CV] svc__C=10.0, svc__gamma=0.01 ....................................
    [CV] ........................... svc__C=10.0, svc__gamma=0.01 -   8.7s
    [CV] svc__C=10.0, svc__gamma=0.1 .....................................
    [CV] ............................ svc__C=10.0, svc__gamma=0.1 -   8.6s
    [CV] svc__C=10.0, svc__gamma=0.1 .....................................
    [CV] ............................ svc__C=10.0, svc__gamma=0.1 -   8.6s
    [CV] svc__C=10.0, svc__gamma=0.1 .....................................
    [CV] ............................ svc__C=10.0, svc__gamma=0.1 -   8.6s
    [CV] svc__C=10.0, svc__gamma=1.0 .....................................
    [CV] ............................ svc__C=10.0, svc__gamma=1.0 -   8.5s
    [CV] svc__C=10.0, svc__gamma=1.0 .....................................
    [CV] ............................ svc__C=10.0, svc__gamma=1.0 -   8.6s
    [CV] svc__C=10.0, svc__gamma=1.0 .....................................
    [CV] ............................ svc__C=10.0, svc__gamma=1.0 -   9.3s
    [CV] svc__C=10.0, svc__gamma=10.0 ....................................
    [CV] ........................... svc__C=10.0, svc__gamma=10.0 -   8.8s
    [CV] svc__C=10.0, svc__gamma=10.0 ....................................
    [CV] ........................... svc__C=10.0, svc__gamma=10.0 -   8.9s
    [CV] svc__C=10.0, svc__gamma=10.0 ....................................
    [CV] ........................... svc__C=10.0, svc__gamma=10.0 -   8.7s

    12 hyperparameter combinations x 3-fold cross-validation = 36 fits, taking 5.2 minutes
    [Parallel(n_jobs=1)]: Done  36 out of  36 | elapsed:  5.2min finished

    Best parameters   best mean cross-validation score
    {'svc__C': 10.0, 'svc__gamma': 0.1} 0.7906666666666666
    Test score of the best model
    0.8226666666666667

    '''
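The log line "Fitting 3 folds for each of 12 candidates, totalling 36 fits" can be reproduced by hand. The sketch below loops over the same 4 x 3 grid with `cross_val_score` and compares the result with `GridSearchCV`. It assumes the bundled iris dataset and a bare `SVC` rather than the article's text pipeline, so the scores are unrelated to the ones printed above; names such as `manual_scores` and `n_fits` are illustrative.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, ParameterGrid, cross_val_score
from sklearn.svm import SVC

# Small bundled dataset (an assumption; the article uses 20 newsgroups)
X, y = load_iris(return_X_y=True)
param_grid = {"gamma": np.logspace(-2, 1, 4), "C": np.logspace(-1, 1, 3)}

manual_scores = []
n_fits = 0
for params in ParameterGrid(param_grid):
    # One 3-fold cross-validation per parameter combination
    fold_scores = cross_val_score(SVC(**params), X, y, cv=3)
    n_fits += len(fold_scores)
    manual_scores.append((fold_scores.mean(), params))

best_manual_score = max(score for score, _ in manual_scores)

# GridSearchCV performs the same loop and keeps the best mean score
gs = GridSearchCV(SVC(), param_grid, cv=3)
gs.fit(X, y)

print(n_fits, best_manual_score, gs.best_score_)
```

With `refit=True` (the default, and what the script sets explicitly), `GridSearchCV` additionally retrains the best combination on the full training data, which is why `gs.score(x_test, y_test)` can be called directly afterwards.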


  • Original article: https://www.cnblogs.com/Lin-Yi/p/9000989.html