• cross-validation


    交叉验证

    • 模型拟合的程度好坏取决于数据的划分(主要指训练集和测试集的划分)-
    • 不代表模型的泛化能力datacamp

    Cross-validation is a vital step in evaluating a model. It maximizes the amount of data that is used to train the model, as during the course of training, the model is not only trained, but also tested on all of the available data.
    最大化的选择模型分训练集,使其泛化能力更好
    换句话说,就是选择最佳参数

    基本思想

    交叉验证的基本思想是把在某种意义下将原始数据(dataset)进行分组,一部分做为训练集(train set),另一部分做为验证集(validation set or test set),首先用训练集对分类器进行训练,再利用验证集来测试训练得到的模型(model),以此来做为评价分类器的性能指标。百度百科

    用途

    • 准确的调整模型的超参数(Hyperparameter),且这组参数对不同的数据,表现相对稳定
    • 在某些分类场景,你可以同时使用逻辑回归、决策树或聚类等多种算法建模,当不确定哪种算法效果更好时,可以使用交叉验证

    例子

    • 为了降低测试数据产生的偶然性,更好的做法便是采用「交叉验证」,还是以切分 5 份数据为例,交叉验证的做法是,对于同一个算法,同时训练出 5 个模型,每个模型采用不同的测试数据(例如模型 1 选用第 1 份,模型 2 选用第 2 份,以此类推),在所有模型都完成测试后,再对这 5 个模型的评估结果求平均,便可以得到一个相对稳定且更有说服力的算法。

    • 举个具体的例子,假设我们的模型采用决策树算法,该算法有个超参数是树的深度 height,我们可以将其设为 2,也可以设为 3,但不清楚设哪个数比较好,此时我们就可以使用「交叉验证」来帮我们决策,首先还是将数据 5等分,对每一个参数值,我们都训练 5次,输出 5种可能的测试结果,然后对这5个结果取平均,即可测试出哪个是我们想要的结果。简书

    可用API

    sklearn.model_selection

    5折交叉验证

    # Import the necessary modules
    from sklearn.linear_model import LinearRegression
    from sklearn.model_selection import cross_val_score
    
    # Create a linear regression object: reg
    reg = LinearRegression()
    
    # Compute 5-fold cross-validation scores: cv_scores
    cv_scores = cross_val_score(reg, X, y, cv=5)
    
    # Print the 5-fold cross-validation scores
    print(cv_scores)
    
    # Print the average 5-fold cross-validation score
    print("Average 5-Fold CV Score: {}".format(np.mean(cv_scores)))
    

    K折交叉验证

    # Import necessary modules
    from sklearn.linear_model import LinearRegression
    from sklearn.model_selection import cross_val_score
    
    
    # Create a linear regression object: reg
    reg = LinearRegression()
    
    # Perform 3-fold CV
    cvscores_3 = cross_val_score(reg, X, y, cv = 3)
    print(np.mean(cvscores_3))
    
    # Perform 10-fold CV
    cvscores_10 = cross_val_score(reg, X, y, cv = 10)
    print(np.mean(cvscores_10))
    
    <script.py> output:
        0.8718712782622108
        0.8436128620131201
    
  • 相关阅读:
    易耗品管理 第三四表 查询的存储过程
    [zz]使用vc编译libsvm
    程序调试小bug
    Ubuntu下安装配置OpenNI, OpenCV
    关于Linux下使用OpenCv读取视频打不开的问题
    jQuery实现图片点击放大
    关于 QtDBus 的种种
    javascript计时器的实现
    [QT]没有选择Debug构建方式.为文件的某行设置断点可能会失败
    linux firefox 不显示英文的解决
  • 原文地址:https://www.cnblogs.com/gaowenxingxing/p/12302890.html
Copyright © 2020-2023  润新知