• Machine Learning Notes: KFold and StratifiedKFold Cross-Validation in sklearn


    I. Cross-Validation

    Two cross-validation splitters commonly used in machine learning are KFold and StratifiedKFold.

    Import:

    from sklearn.model_selection import KFold, StratifiedKFold
    
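    • StratifiedKFold: stratified splitting (the idea of stratified random sampling). The class proportions in each validation fold match those of the full dataset as closely as possible, so the labels must be passed when splitting.
    • KFold: splits the data into folds without looking at the labels. By default (shuffle=False) the samples are not shuffled and each validation fold is a consecutive block; set shuffle=True for a random split.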
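
    To make the difference concrete, the following minimal sketch compares the share of the minority class in each validation fold. The toy data (20 samples, 4 of them in class 1) and the helper name positive_share_per_fold are assumptions made up for illustration.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

# Hypothetical toy data: 20 samples, the first 16 in class 0 and the last 4 in class 1.
X = np.arange(40).reshape(20, 2)
y = np.array([0] * 16 + [1] * 4)


def positive_share_per_fold(splitter):
    """Return the fraction of class-1 samples in each validation fold."""
    return [float(y[test_idx].mean()) for _, test_idx in splitter.split(X, y)]


# KFold ignores y; StratifiedKFold uses it to balance the folds.
kfold_shares = positive_share_per_fold(KFold(n_splits=4))
stratified_shares = positive_share_per_fold(StratifiedKFold(n_splits=4))

print("overall share of class 1:", y.mean())
print("KFold:", kfold_shares)
print("StratifiedKFold:", stratified_shares)
```

    With unshuffled KFold, all class-1 samples land in the last validation fold, while StratifiedKFold keeps each fold at the overall share of 0.2.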

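    II. KFold Cross-Validation

    1. Syntax

    sklearn.model_selection.KFold(n_splits=5, # number of folds, at least 2 (default is 5 since scikit-learn 0.22)
                                 shuffle=False, # whether to shuffle the samples before splitting
                                 random_state=None) # seed for the shuffle; only used when shuffle=True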
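
    As a rough illustration of the shuffle and random_state parameters, the sketch below compares the validation folds with and without shuffling. The six-sample array and the helper name validation_folds are assumptions chosen for brevity.

```python
import numpy as np
from sklearn.model_selection import KFold

X = np.arange(12).reshape(6, 2)  # six hypothetical samples


def validation_folds(splitter):
    """Collect the validation indices of every fold as plain lists."""
    return [test_idx.tolist() for _, test_idx in splitter.split(X)]


# shuffle=False (the default): folds are consecutive blocks of indices.
ordered = validation_folds(KFold(n_splits=3))
print(ordered)  # [[0, 1], [2, 3], [4, 5]]

# shuffle=True with a fixed integer seed: folds are random but reproducible.
first_run = validation_folds(KFold(n_splits=3, shuffle=True, random_state=42))
second_run = validation_folds(KFold(n_splits=3, shuffle=True, random_state=42))
print(first_run == second_run)  # True
```

    Fixing random_state is mainly useful when results need to be compared across runs or across models evaluated on identical folds.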
    

    2. Hands-on example

    • get_n_splits -- returns the number of folds
    • split -- yields the (train indices, test indices) pair for each fold

    import numpy as np
    from sklearn.model_selection import KFold, StratifiedKFold
    
    X = np.array([[1,2], [3,4], [5,6], [7,8]])
    y = np.array([1,2,3,4])
    kf = KFold(n_splits=2)
    kf.get_n_splits() # 2
    print(kf) # KFold(n_splits=2, random_state=None, shuffle=False)
    
    # split only needs the data here; the labels are not required
    for train_index, test_index in kf.split(X):
        print("TRAIN:", train_index, "TEST:", test_index)
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]
    '''
    TRAIN: [2 3] TEST: [0 1]
    TRAIN: [0 1] TEST: [2 3]
    '''
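
    The index arrays returned by split are normally used to fit and score a model on each fold. The sketch below shows one possible loop; the iris dataset and LogisticRegression are stand-ins chosen for illustration and are not part of the original example.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold

X, y = load_iris(return_X_y=True)

# The iris samples are sorted by class, so shuffle before splitting.
kf = KFold(n_splits=5, shuffle=True, random_state=0)

scores = []
for train_index, test_index in kf.split(X):
    model = LogisticRegression(max_iter=1000)  # fresh model for every fold
    model.fit(X[train_index], y[train_index])
    # score() returns the accuracy on the held-out fold
    scores.append(model.score(X[test_index], y[test_index]))

print("accuracy per fold:", np.round(scores, 3))
print("mean accuracy:", np.mean(scores))
```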
    

    III. StratifiedKFold Cross-Validation

    1. Syntax

    sklearn.model_selection.StratifiedKFold(n_splits=5, # same parameters as KFold
                                           shuffle=False,
                                           random_state=None)
    

    2. Hands-on example

    import numpy as np
    from sklearn.model_selection import KFold, StratifiedKFold
    
    X = np.array([[1,2], [3,4], [5,6], [7,8]])
    y = np.array([1,0,0,1])
    skf = StratifiedKFold(n_splits=2)
    skf.get_n_splits() # 2
    print(skf) # StratifiedKFold(n_splits=2, random_state=None, shuffle=False)
    
    # pass both the data and the labels
    for train_index, test_index in skf.split(X, y):
        print("TRAIN:", train_index, "TEST:", test_index)
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]
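
    Instead of looping over the folds by hand, a splitter object can also be passed to cross_val_score through its cv argument. The sketch below assumes the iris dataset and a LogisticRegression classifier purely for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score

X, y = load_iris(return_X_y=True)

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
model = LogisticRegression(max_iter=1000)

# cross_val_score uses skf to generate the folds and fits a clone of the model on each one.
scores = cross_val_score(model, X, y, cv=skf)

print("accuracy per fold:", scores)
print("mean accuracy:", scores.mean())
```

    Because the labels drive the stratification, y has to be supplied here, just as in skf.split(X, y).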
    

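    Note: the number of folds must not exceed the number of samples in each class. If n_splits is greater than the sample count of every class, an error is raised (if only the smallest class has fewer samples than n_splits, scikit-learn issues a warning instead):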

    ValueError: n_splits=2 cannot be greater than the number of members in each class.
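
    A minimal way to reproduce this error, assuming the same four-sample data as in the example above: each class has two members, so requesting three folds fails.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
y = np.array([1, 0, 0, 1])  # two samples per class

error_message = None
try:
    # Three folds requested, but every class has only two members.
    # split() is a generator, so list() is needed to actually trigger the check.
    list(StratifiedKFold(n_splits=3).split(X, y))
except ValueError as exc:
    error_message = str(exc)

print(error_message)
```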
    

    Reference: sklearn.model_selection.KFold

    Reference: sklearn.model_selection.StratifiedKFold

    Reference: KFold and StratifiedKFold in Python sklearn

  • Original article: https://www.cnblogs.com/hider/p/15948302.html