• Scikit-learn multilabel classification (large training data, MultiOutputClassifier, partial_fit)


    Core code:

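    import os
    
    import numpy as np
    # from sklearn.linear_model import LogisticRegression
    # from sklearn.svm import SVC
    from sklearn.multioutput import MultiOutputClassifier
    from sklearn.naive_bayes import MultinomialNB
    
    # Project-specific modules from the original post (their code is not shown)
    from pathConfig import data_dir
    from utils.data_util import load_pickle
    from utils.vocab_util import vocab_to_index_dict
    # load_train_file(path) -> (X, Y), load_test_data() -> (X, Y) and label_vocab
    # (the full label list) are further project helpers the post does not show;
    # Y is assumed to be an (n_samples, n_labels) 0/1 indicator matrix.
    
    # Training and test data directories
    train_dir = os.path.join(data_dir, "train")
    test_dir = os.path.join(data_dir, "test")
    
    # Train
    # SVC and LogisticRegression have no partial_fit, so they cannot be used in the
    # incremental loop below; MultinomialNB (or e.g. SGDClassifier) can.
    # classifier = SVC(kernel='linear', probability=True)
    # classifier = LogisticRegression()
    classifier = MultinomialNB()  # requires non-negative features (e.g. term counts)
    print("Training classifier", classifier)
    # One binary classifier per label; the per-label fits run in parallel (24 jobs)
    clf = MultiOutputClassifier(classifier, n_jobs=24)
    
    # Every label is binary, so each per-label classifier gets classes [0, 1]
    classes = [[0, 1]] * len(label_vocab)
    
    # Feed the training files one at a time so the full data set never sits in memory
    for fname in os.listdir(train_dir):
        fpath = os.path.join(train_dir, fname)
        print("loading file", fpath)
        train_X, train_y = load_train_file(fpath)
        print("partial_fitting...")
        clf.partial_fit(train_X, train_y, classes=classes)
        # break  # uncomment to train on the first file only (quick debugging)
    
    # Test
    test_X, test_y = load_test_data()
    
    # predict_proba returns a list with one array per label, each of shape
    # (n_test_samples, 2): column 0 = P(label=0), column 1 = P(label=1)
    y_pred = clf.predict_proba(test_X)
    
    # Keep P(label=1) for every label -> array of shape (n_test_samples, n_labels)
    y_pred_processed = np.column_stack([proba[:, 1] for proba in y_pred])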
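
The loop above depends on project helpers that the post does not show, so here is a self-contained sketch of the same pattern on synthetic data. The chunk generator `make_chunk`, the feature/label counts and the number of chunks are hypothetical stand-ins for the training files; what it illustrates is that `MultiOutputClassifier.partial_fit` is called once per chunk with the full `classes` list, and afterwards holds one fitted `MultinomialNB` per label.

```python
import numpy as np
from sklearn.multioutput import MultiOutputClassifier
from sklearn.naive_bayes import MultinomialNB

rng = np.random.default_rng(0)
n_features, n_labels = 20, 3


def make_chunk(n_samples):
    """Hypothetical stand-in for one training file: count features + 0/1 label matrix."""
    X = rng.poisson(1.0, size=(n_samples, n_features))
    # Each label depends on a different feature, so the labels are learnable.
    Y = (X[:, :n_labels] > 1).astype(int)
    return X, Y


clf = MultiOutputClassifier(MultinomialNB(), n_jobs=1)
# Every label is binary, so each per-label estimator is told its classes are [0, 1].
classes = [np.array([0, 1])] * n_labels

for _ in range(5):  # e.g. five training files streamed one at a time
    X_chunk, Y_chunk = make_chunk(200)
    clf.partial_fit(X_chunk, Y_chunk, classes=classes)

X_test, Y_test = make_chunk(100)
Y_pred = clf.predict(X_test)
print(Y_pred.shape)  # (100, 3)
print(len(clf.estimators_))  # 3
```

`classes` is only required on the first `partial_fit` call; passing the same list again on later calls is harmless, which keeps the loop body uniform.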
    
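Turning the per-label probabilities into 0/1 predictions and scoring them could look like the following sketch. It assumes a fixed 0.5 threshold for every label and uses `f1_score` (micro-averaged) and `hamming_loss` from `sklearn.metrics`; the toy data and the 200/100 train/test split are made up for illustration.

```python
import numpy as np
from sklearn.metrics import f1_score, hamming_loss
from sklearn.multioutput import MultiOutputClassifier
from sklearn.naive_bayes import MultinomialNB

# Hypothetical toy data: count features and a 4-label 0/1 target matrix.
rng = np.random.default_rng(1)
X = rng.poisson(1.0, size=(300, 10))
Y = (X[:, :4] > 1).astype(int)
X_train, Y_train, X_test, Y_test = X[:200], Y[:200], X[200:], Y[200:]

clf = MultiOutputClassifier(MultinomialNB())
clf.partial_fit(X_train, Y_train, classes=[np.array([0, 1])] * Y.shape[1])

# predict_proba returns a list with one (n_samples, 2) array per label;
# column 0 is P(label = 0) and column 1 is P(label = 1).
proba_list = clf.predict_proba(X_test)

# Stack the positive-class columns into an (n_samples, n_labels) matrix.
proba = np.column_stack([p[:, 1] for p in proba_list])

# Assumed fixed threshold of 0.5 for every label.
Y_hat = (proba >= 0.5).astype(int)

print(proba.shape)  # (100, 4)
print("micro-F1:", f1_score(Y_test, Y_hat, average="micro", zero_division=0))
print("Hamming loss:", hamming_loss(Y_test, Y_hat))
```

A single global threshold is the simplest choice; with imbalanced labels, tuning a separate threshold per label on held-out data often works better.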
  • Original post: https://www.cnblogs.com/XBWer/p/13503796.html