• Development Log 3


    Last time I finished keyword extraction; this time I implemented automatic classification.

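    While working on automatic classification, I searched online for many approaches, including ones based on Spark, Python, Java and so on, but the material was scattered and inconsistent.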

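    I then found a Python machine-learning approach that classifies text automatically with scikit-learn: jieba word segmentation, TF-IDF features and a one-vs-rest linear SVM. The code is as follows: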

    #!/usr/bin/env python
    # coding=utf-8
    import re

    import jieba
    import pymysql
    from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
    from sklearn.multiclass import OneVsRestClassifier
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import MultiLabelBinarizer
    from sklearn.svm import LinearSVC

    # Separator used when writing several labels back into the `type` column
    LABEL_SEP = "  "


    def jieba_tokenizer(x):
        # jieba "full mode" (cut_all=True) returns every word it can find; lcut returns a list
        return jieba.lcut(x, cut_all=True)


    def filter_html(s):
        # Strip HTML tags from the article body; treat NULL content as empty text
        return re.sub(r'<[^>]+>', '', s or '', flags=re.S)


    def write_sql(conn, record_id, labels):
        # `index` is a reserved word in MySQL, so it must be quoted with backticks.
        # Values are passed as query parameters instead of being concatenated into the SQL.
        sql = "UPDATE info_tech SET type = %s WHERE `index` = %s"
        try:
            with conn.cursor() as cur:
                cur.execute(sql, (labels, record_id))
            conn.commit()
        except pymysql.MySQLError as e:
            conn.rollback()
            print("Error:", e)


    # Connect to the MySQL database (one connection reused for reads and writes)
    conn = pymysql.connect(host="localhost", port=3306, user="root", password="root",
                           database="dazuoye", charset="utf8")
    cursor = conn.cursor()

    # Training samples
    cursor.execute("SELECT `index`, title3, type, content FROM info_tech")

    txt_ret = []
    # Example of the label set stored in the database:
    #class_ret = [["信息化"],["大数据"],["云计算"],["区块链"],["智慧城市"],["工业互联网"],["信息安全"],["操作系统"],["计算机"],["法律法规"],["信息化战略"]]
    class_ret = []
    for row in cursor.fetchall():
        # Labels in `type` are whitespace-separated; split() handles one or more spaces
        labels = (row[2] or "").split()
        if not labels:
            continue  # unlabelled rows cannot be used for training
        txt_ret.append(filter_html(row[3]))
        class_ret.append(labels)

    X_train = txt_ret
    print(class_ret)

    classifier = Pipeline([
        # token_pattern=None: the custom tokenizer replaces the default regex
        ('counter', CountVectorizer(tokenizer=jieba_tokenizer, token_pattern=None)),
        ('tfidf', TfidfTransformer()),
        ('clf', OneVsRestClassifier(LinearSVC())),
    ])
    # Turn each list of labels into a 0/1 indicator row, one column per label
    mlb = MultiLabelBinarizer()
    Y_train = mlb.fit_transform(class_ret)
    classifier.fit(X_train, Y_train)
    # Subset accuracy on the training data itself (optimistic, not a held-out score)
    print(classifier.score(X_train, Y_train))

    # Test data
    cursor.execute("SELECT `index`, title3, keyword, content FROM info_tech")
    test_txt_set = []
    test_id_ret = []
    for row in cursor.fetchall():
        test_txt_set.append(filter_html(row[3]))
        test_id_ret.append(row[0])
    X_test = test_txt_set

    prediction = classifier.predict(X_test)
    # Map indicator rows back to tuples of label names
    result = mlb.inverse_transform(prediction)

    # Show the results and write them back to the database
    for record_id, labels in zip(test_id_ret, result):
        classstr = LABEL_SEP.join(labels)
        print("ID:" + str(record_id) + " =>class:" + classstr)
        write_sql(conn, record_id, classstr)

    cursor.close()
    conn.close()
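    The `('tfidf', TfidfTransformer())` step reweights the raw word counts produced by `CountVectorizer`. As a minimal pure-Python sketch of that weighting, assuming scikit-learn's defaults (`smooth_idf=True`, `sublinear_tf=False`, `norm='l2'`), the function below computes idf(t) = ln((1 + n) / (1 + df(t))) + 1 and scales each document vector to unit length. The function name `tfidf` and the token lists are hypothetical toy data standing in for jieba output; this illustrates the idea and does not replace the library.

```python
import math
from collections import Counter


def tfidf(docs):
    """Hypothetical helper: L2-normalised TF-IDF rows for pre-tokenised documents.

    Assumes scikit-learn's TfidfTransformer defaults: raw counts as tf,
    idf(t) = ln((1 + n) / (1 + df(t))) + 1, then each row scaled to unit length.
    """
    n = len(docs)
    counts = [Counter(doc) for doc in docs]
    # Vocabulary in sorted order (CountVectorizer also sorts its feature names)
    vocab = sorted({term for c in counts for term in c})
    # Document frequency: how many documents contain each term
    df = {t: sum(1 for c in counts if t in c) for t in vocab}
    # Smoothed inverse document frequency
    idf = {t: math.log((1 + n) / (1 + df[t])) + 1 for t in vocab}

    rows = []
    for c in counts:
        row = [c[t] * idf[t] for t in vocab]
        norm = math.sqrt(sum(v * v for v in row))
        # Leave an all-zero row (empty document) unchanged
        rows.append([v / norm for v in row] if norm else row)
    return vocab, rows


# Toy token lists standing in for jieba output (hypothetical data)
docs = [
    ["云计算", "大数据", "云计算"],
    ["大数据", "区块链"],
    ["信息安全"],
]
vocab, rows = tfidf(docs)
for doc, row in zip(docs, rows):
    print(doc, [round(v, 3) for v in row])
```

    Words that occur in fewer documents (here 区块链 compared with 大数据) get a larger idf and therefore a larger weight; each LinearSVC inside OneVsRestClassifier then learns one weight per vocabulary column for its label.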

    Reference tutorial: https://morvanzhou.github.io/tutorials/machine-learning/sklearn/

    Received advice from: 王莉

  • Original post: https://www.cnblogs.com/lovema1210/p/10666117.html