I recently took part in a competition for which I wrote a spam SMS classifier, and this post is a record of how it was done.
The official data comes as csv files: the training set has 800,000 records and the test set has 200,000. Each training record has the format: line number, label (0 for a normal message, 1 for spam), message content. Each test record has the format: line number, message content. The required output format is: line number, label, saved as a csv file.
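For illustration only (these lines are made up, not real competition data, and the input fields are separated by whitespace): a training record looks roughly like "17 0 会议改到明天下午三点", a test record like "17 会议改到明天下午三点", and an output record like "17,0".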
The approach can be summarized in the following steps:
1. Read the files and load the data.
2. Split each line into line number, label, and message content. Since the message content may itself contain spaces, the string cannot simply be split with split(); instead, the regular expression module re is used to match and split each line (see the sketch after this list).
3. Store the split results in a MySQL database, so that later test runs can read them straight from the database and skip this step.
4. Segment the message content into words with the third-party library jieba: https://github.com/fxsjy/jieba
5. Use the segmentation results to train the model. The training algorithm is naive Bayes, available in the third-party library Scikit-Learn: http://scikit-learn.org/stable
6. Read the test set from the database, classify each message, and write the results to the output file.
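A minimal sketch of the splitting in step 2, written in the same Python 2 style as the code below. The pattern is the one later used for the training set in ImportIntoDB.py; the sample line is made up for illustration:

# -*- coding:utf-8 -*-
import re

# line number, then the label (0 or 1), then the message content, which may itself contain spaces
train_pattern = re.compile(u'([0-9]+).*?([01])(.*)')

line = u'233 0 会议改到 明天下午三点'  # made-up example line
match = train_pattern.match(line)
sms_id, sms_type, content = match.group(1), match.group(2), match.group(3).lstrip()
print sms_id, sms_type, content
# a plain line.split(' ') would instead break the message content into several pieces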
The final implementation consists of 4 .py files:
1. ImportIntoDB.py preprocesses the data and imports it into the database; it is only used the first time.
2. DataHandler.py reads the data from the database, performs word segmentation, and prepares the data used to train the model.
3. Classifier.py reads the test set from the database, classifies it with the trained model, and writes the results to a file.
4. Main.py is the entry point of the program.
On average, a full run of the program takes between 260 and 270 seconds. The code follows:
ImportIntoDB.py:
# -*- coding:utf-8 -*-
__author__ = 'Jz'

import MySQLdb
import codecs
import re
import time

# txt_path = 'D:/coding_file/python_file/Big Data/trash message/train80w.txt'
txt_path = 'D:/coding_file/python_file/Big Data/trash message/test20w.txt'

# use regular expression to split string into parts
# split_pattern_80w = re.compile(u'([0-9]+).*?([01])(.*)')
split_pattern_20w = re.compile(u'([0-9]+)(.*)')

txt = codecs.open(txt_path, 'r')
lines = txt.readlines()
start_time = time.time()

# connect mysql database
con = MySQLdb.connect(host = 'localhost', port = 3306, user = 'root', passwd = '*****', db = 'TrashMessage', charset = 'UTF8')
cur = con.cursor()

# insert into 'train' table
# sql = 'insert into train(sms_id, sms_type, content) values (%s, %s, %s)'
# for line in lines:
#     match = re.match(split_pattern_80w, line)
#     sms_id, sms_type, content = match.group(1), match.group(2), match.group(3).lstrip()
#     cur.execute(sql, (sms_id, sms_type, content))
#     print sms_id
# # commit transaction
# con.commit()

# insert into 'test' table
sql = 'insert into test(sms_id, content) values (%s, %s)'
for line in lines:
    match = re.match(split_pattern_20w, line)
    sms_id, content = match.group(1), match.group(2).lstrip()
    cur.execute(sql, (sms_id, content))
    print sms_id
# commit transaction
con.commit()

cur.close()
con.close()
txt.close()
end_time = time.time()
print 'time-consuming: ' + str(end_time - start_time) + 's.'
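ImportIntoDB.py assumes the TrashMessage database and the train and test tables already exist. The post does not show the schema, so the column types in the following one-off setup snippet are assumptions that merely have to be compatible with the insert statements above:

# -*- coding:utf-8 -*-
# one-off schema setup; column types are assumed, not taken from the original post
import MySQLdb

con = MySQLdb.connect(host = 'localhost', port = 3306, user = 'root', passwd = '*****', db = 'TrashMessage', charset = 'UTF8')
cur = con.cursor()
cur.execute('create table if not exists train(sms_id int primary key, sms_type int, content text)')
cur.execute('create table if not exists test(sms_id int primary key, content text)')
con.commit()
cur.close()
con.close()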
DataHandler.py:
# -*- coding:utf-8 -*-
__author__ = 'Jz'

import MySQLdb
import jieba
import re

class DataHandler:
    def __init__(self):
        try:
            self.con = MySQLdb.connect(host = 'localhost', port = 3306, user = 'root', passwd = '*****', db = 'TrashMessage', charset = 'UTF8')
            self.cur = self.con.cursor()
        except MySQLdb.OperationalError, oe:
            print 'Connection error! Details:', oe

    def __del__(self):
        self.cur.close()
        self.con.close()

    # obsolete function
    # def getConnection(self):
    #     return self.con

    # obsolete function
    # def getCursor(self):
    #     return self.cur

    def query(self, sql):
        self.cur.execute(sql)
        result_set = self.cur.fetchall()
        return result_set

    def resultSetTransformer(self, train, test):
        # list of words divided by jieba module after de-duplication
        train_division = []
        test_division = []
        # list of classification of each message
        train_class = []
        # divide messages into words
        for record in train:
            train_class.append(record[1])
            division = jieba.cut(record[2])
            filtered_division_set = set()
            for word in division:
                filtered_division_set.add(word + ' ')
            division = list(filtered_division_set)
            str_word = ''.join(division)
            train_division.append(str_word)

        # handle test set in a similar way as above
        for record in test:
            division = jieba.cut(record[1])
            filtered_division_set = set()
            for word in division:
                filtered_division_set.add(word + ' ')
            division = list(filtered_division_set)
            str_word = ''.join(division)
            test_division.append(str_word)

        return train_division, train_class, test_division
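To make resultSetTransformer easier to follow: jieba.cut returns a generator of tokens, and the loops above de-duplicate the tokens of each message and join them into a single space-separated string, which is the document format that scikit-learn's vectorizers expect. A tiny sketch with a made-up message (since a set is unordered, the token order in the output is arbitrary):

# -*- coding:utf-8 -*-
import jieba

message = u'会议改到明天下午三点,请准时参加'  # made-up example message
tokens = jieba.cut(message)                      # generator of unicode tokens
deduplicated = set(word + ' ' for word in tokens)
print ''.join(deduplicated)                      # space-separated tokens in arbitrary order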
Classifier.py:
# -*- coding:utf-8 -*-
__author__ = 'Jz'

from DataHandler import DataHandler
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
import time

class Classifier:
    def __init__(self):
        start_time = time.time()
        self.data_handler = DataHandler()
        # get result set
        self.train = self.data_handler.query('select * from train')
        self.test = self.data_handler.query('select * from test')
        self.train_division, self.train_class, self.test_division = self.data_handler.resultSetTransformer(self.train, self.test)
        end_time = time.time()
        print 'Classifier finished initializing, time-consuming: ' + str(end_time - start_time) + 's.'

    def getMatrices(self):
        start_time = time.time()
        # convert a collection of raw documents to a matrix of TF-IDF features
        self.tfidf_vectorizer = TfidfVectorizer()
        # learn vocabulary and idf, return term-document matrix [sample, feature]
        self.train_count_matrix = self.tfidf_vectorizer.fit_transform(self.train_division)
        # transform the count matrix of the train set to a normalized tf-idf representation
        self.tfidf_transformer = TfidfTransformer()
        self.train_tfidf_matrix = self.tfidf_transformer.fit_transform(self.train_count_matrix)
        end_time = time.time()
        print 'Classifier finished getting matrices, time-consuming: ' + str(end_time - start_time) + 's.'

    def classify(self):
        self.getMatrices()
        start_time = time.time()
        # convert a collection of text documents to a matrix of token counts
        # scikit-learn doesn't support chinese vocabulary
        test_tfidf_vectorizer = CountVectorizer(vocabulary = self.tfidf_vectorizer.vocabulary_)
        # learn the vocabulary dictionary and return term-document matrix
        test_count_matrix = test_tfidf_vectorizer.fit_transform(self.test_division)
        # transform a count matrix to a normalized tf or tf-idf representation
        test_tfidf_transformer = TfidfTransformer()
        test_tfidf_matrix = test_tfidf_transformer.fit(self.train_count_matrix).transform(test_count_matrix)

        # the multinomial Naive Bayes classifier is suitable for classification with discrete features
        # (e.g., word counts for text classification)
        naive_bayes = MultinomialNB(alpha = 0.65)
        naive_bayes.fit(self.train_tfidf_matrix, self.train_class)
        prediction = naive_bayes.predict(test_tfidf_matrix)

        # output result to a csv file, one 'sms_id,sms_type' record per line
        index = 0
        csv = open('result.csv', 'w')
        for sms_type in prediction:
            csv.write(str(self.test[index][0]) + ',' + str(sms_type) + '\n')
            index += 1
        csv.close()
        end_time = time.time()
        print 'Classifier finished classifying, time-consuming: ' + str(end_time - start_time) + 's.'
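One design point worth noting in classify(): the test set is vectorized with CountVectorizer(vocabulary = self.tfidf_vectorizer.vocabulary_), so the test matrix gets exactly the same feature columns as the training matrix, and words that never appeared in any training message are simply ignored.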
Main.py:
# -*- coding:utf-8 -*-
__author__ = 'Jz'

import time
from Classifier import Classifier

start_time = time.time()
classifier = Classifier()
classifier.classify()
end_time = time.time()
print 'total time-consuming: ' + str(end_time - start_time) + 's.'