• 使用Python 2.7实现的垃圾短信识别器


      最近参加比赛,写了一个垃圾短信识别器,在这里做一下记录。

      官方提供的数据是csv文件,其中训练集有80万条数据,测试集有20万条数据,训练集的格式为:行号 标记(0为普通短信,1为垃圾短信) 短信内容;测试集的格式为: 行号 短信内容;要求输出的数据格式要求为: 行号 标记,以csv格式保存。

      实现的原理可概括为以下几步:

        1.读取文件,输入数据

        2.对数据进行分割,将每一行数据分成行号、标记、短信内容。由于短信内容中可能存在空格,故不能简单地用split()分割字符串,应该用正则表达式模块re进行匹配分割。

        3.将分割结果存入数据库(MySQL),方便下次测试时直接从数据库读取结果,省略步骤。

        4.对短信内容进行分词,这一步用到了第三方库结巴分词:https://github.com/fxsjy/jieba

        5.将分词的结果用于训练模型,训练的算法为朴素贝叶斯算法,可调用第三方库Scikit-Learn:http://scikit-learn.org/stable 

        6.从数据库中读取测试集,进行判断,输出结果并写入文件。

      最终实现出来一共有4个py文件:

        1.ImportIntoDB.py 将数据进行预处理并导入数据库,仅在第一次使用。

        2.DataHandler.py 从数据库中读取数据,进行分词,随后处理数据,训练模型。

        3.Classifier.py 从数据库中读取测试集数据,利用训练好的模型进行判断,输出结果到文件中。

        4.Main.py 程序的入口

     

      最终程序每次运行耗时平均在260秒-270秒之间,附代码:

      ImportIntoDB.py:

     1 # -*- coding:utf-8 -*-
     2 __author__ = 'Jz'
     3 
     4 import MySQLdb
     5 import codecs
     6 import re
     7 import time
     8 
     9 # txt_path = 'D:/coding_file/python_file/Big Data/trash message/train80w.txt'
    10 txt_path = 'D:/coding_file/python_file/Big Data/trash message/test20w.txt'
    11 
    12 # use regular expression to split string into parts
    13 # split_pattern_80w = re.compile(u'([0-9]+).*?([01])(.*)')
    14 split_pattern_20w = re.compile(u'([0-9]+)(.*)')
    15 
    16 txt = codecs.open(txt_path, 'r')
    17 lines = txt.readlines()
    18 start_time = time.time()
    19 
    20 #connect mysql database
    21 con = MySQLdb.connect(host = 'localhost', port = 3306, user = 'root', passwd = '*****', db = 'TrashMessage', charset = 'UTF8')
    22 cur = con.cursor()
    23 
    24 # insert into 'train' table
    25 # sql = 'insert into train(sms_id, sms_type, content) values (%s, %s, %s)'
    26 # for line in lines:
    27 #     match = re.match(split_pattern_80w, line)
    28 #     sms_id, sms_type, content = match.group(1), match.group(2), match.group(3).lstrip()
    29 #     cur.execute(sql, (sms_id, sms_type, content))
    30 #     print sms_id
    31 # # commit transaction
    32 # con.commit()
    33 
    34 # insert into 'test' table
    35 sql = 'insert into test(sms_id, content) values (%s, %s)'
    36 for line in lines:
    37     match = re.match(split_pattern_20w, line)
    38     sms_id, content = match.group(1), match.group(2).lstrip()
    39     cur.execute(sql, (sms_id, content))
    40     print sms_id
    41 # commit transaction
    42 con.commit()
    43 
    44 cur.close()
    45 con.close()
    46 txt.close()
    47 end_time = time.time()
    48 print 'time-consuming: ' + str(end_time - start_time) + 's.'

      DataHandler.py:

     1 # -*- coding:utf-8 -*-
     2 __author__ = 'Jz'
     3 
     4 import MySQLdb
     5 import jieba
     6 import re
     7 
     8 class DataHandler:
     9     def __init__(self):
    10         try:
    11             self.con = MySQLdb.connect(host = 'localhost', port = 3306, user = 'root', passwd = '*****', db = 'TrashMessage', charset = 'UTF8')
    12             self.cur = self.con.cursor()
    13         except MySQLdb.OperationalError, oe:
    14             print 'Connection error! Details:', oe
    15 
    16     def __del__(self):
    17         self.cur.close()
    18         self.con.close()
    19 
    20     # obsolete function
    21     # def getConnection(self):
    22     #     return self.con
    23 
    24     # obsolete function
    25     # def getCursor(self):
    26     #     return self.cur
    27 
    28     def query(self, sql):
    29         self.cur.execute(sql)
    30         result_set = self.cur.fetchall()
    31         return result_set
    32 
    33     def resultSetTransformer(self, train, test):
    34         # list of words divided by jieba module after de-duplication
    35         train_division = []
    36         test_division = []
    37         # list of classification of each message
    38         train_class = []
    39         # divide messages into words
    40         for record in train:
    41             train_class.append(record[1])
    42             division = jieba.cut(record[2])
    43             filtered_division_set = set()
    44             for word in division:
    45                 filtered_division_set.add(word + ' ')
    46             division = list(filtered_division_set)
    47             str_word = ''.join(division)
    48             train_division.append(str_word)        
    49 
    50         # handle test set in a similar way as above
    51         for record in test:
    52             division = jieba.cut(record[1])
    53             filtered_division_set = set()
    54             for word in division:
    55                 filtered_division_set.add(word + ' ')
    56             division = list(filtered_division_set)
    57             str_word = ''.join(division)
    58             test_division.append(str_word)
    59 
    60         return train_division, train_class, test_division

      Classifier.py:

     1 # -*- coding:utf-8 -*-
     2 __author__ = 'Jz'
     3 
     4 from DataHandler import DataHandler
     5 from sklearn.feature_extraction.text import TfidfVectorizer
     6 from sklearn.feature_extraction.text import TfidfTransformer
     7 from sklearn.feature_extraction.text import CountVectorizer
     8 from sklearn.naive_bayes import MultinomialNB
     9 import time
    10 
    11 class Classifier:
    12     def __init__(self):
    13         start_time = time.time()
    14         self.data_handler = DataHandler()
    15         # get result set
    16         self.train = self.data_handler.query('select * from train')
    17         self.test = self.data_handler.query('select * from test')
    18         self.train_division, self.train_class, self.test_division = self.data_handler.resultSetTransformer(self.train, self.test)
    19         end_time = time.time()
    20         print 'Classifier finished initializing, time-consuming:' + str(end_time - start_time) + 's.'
    21 
    22     def getMatrices(self):
    23         start_time = time.time()
    24         # convert a collection of raw documents to a matrix of TF-IDF features.
    25         self.tfidf_vectorizer = TfidfVectorizer()
    26         # learn vocabulary and idf, return term-document matrix [sample, feature]
    27         self.train_count_matrix = self.tfidf_vectorizer.fit_transform(self.train_division)
    28         # transform the count matrix of the train set to a normalized tf-idf representation 
    29         self.tfidf_transformer = TfidfTransformer()
    30         self.train_tfidf_matrix = self.tfidf_transformer.fit_transform(self.train_count_matrix)
    31         end_time = time.time()
    32         print 'Classifier finished getting matrices, time-consuming:' + str(end_time - start_time) + 's.'
    33 
    34     def classify(self):
    35         self.getMatrices()
    36         start_time = time.time()
    37         # convert a collection of text documents to a matrix of token counts
    38         # scikit-learn doesn't support chinese vocabulary
    39         test_tfidf_vectorizer = CountVectorizer(vocabulary = self.tfidf_vectorizer.vocabulary_)
    40         # learn the vocabulary dictionary and return term-document matrix.
    41         test_count_matrix = test_tfidf_vectorizer.fit_transform(self.test_division)
    42         # transform a count matrix to a normalized tf or tf-idf representation
    43         test_tfidf_transformer = TfidfTransformer()
    44         test_tfidf_matrix = test_tfidf_transformer.fit(self.train_count_matrix).transform(test_count_matrix)
    45 
    46         # the multinomial Naive Bayes classifier is suitable for classification with discrete features
    47         # e.g., word counts for text classification).
    48         naive_bayes = MultinomialNB(alpha = 0.65)
    49         naive_bayes.fit(self.train_tfidf_matrix, self.train_class)
    50         prediction = naive_bayes.predict(test_tfidf_matrix)
    51 
    52         # output result to a csv file
    53         index = 0
    54         csv = open('result.csv', 'w')
    55         for sms_type in prediction:
    56             csv.write(str(self.test[index][0]) + ',' + str(sms_type) + '
    ')
    57             index += 1
    58         csv.close()
    59         end_time = time.time()
    60         print 'Classifier finished classifying, time-consuming: ' + str(end_time - start_time) + 's.'

      Main.py:

     1 # -*- coding:utf-8 -*-
     2 __author__ = 'Jz'
     3 
     4 import time
     5 from Classifier import Classifier
     6 
     7 start_time = time.time()
     8 classifier = Classifier()
     9 classifier.classify()
    10 end_time = time.time()
    11 print 'total time-consuming: ' + str(end_time - start_time) + 's.'
  • 相关阅读:
    云原生生态周报 Vol. 16 | CNCF 归档 rkt,容器运行时“上古”之战老兵凋零
    Knative 基本功能深入剖析:Knative Eventing 之 Sequence 介绍
    基于 K8s 做应用发布的工具那么多, 阿里为啥选择灰姑娘般的 Tekton ?
    Serverless 的喧哗与骚动(一)附Serverless行业发展回顾
    239. Sliding Window Maximum
    237. Delete Node in a Linked List
    146. LRU Cache
    140. Word Break II
    165. Compare Version Numbers
    258. Add Digits
  • 原文地址:https://www.cnblogs.com/jzincnblogs/p/4975109.html
Copyright © 2020-2023  润新知