• 【文本分类-中文】textCNN


    目录

    1. 概述
    2. 数据集合
    3. 代码
    4. 结果展示

    一、概述

    在英文分类的基础上,再看看中文分类的,是一种10分类问题(体育,科技,游戏,财经,房产,家居等)的处理。

    二、数据集合

    数据集为新闻,总共有四个数据文件,在/data/cnews目录下,包括内容如下图所示测试集,训练集和验证集,和单词表(最后的单词表cnews.vocab.txt可以不要,因为训练可以自动产生)。数据格式:前面为类别,后面为描述内容。

    训练数据地址:链接: https://pan.baidu.com/s/1ZHh98RrjQpG5Tm-yq73vBQ 提取码:2r04

    其中训练集的格式:

    vocab.txt的格式:每个字一行,其中前面加上PAD。

    三、代码

    3.1 数据采集cnews_loader.py

        1     # coding: utf-8
        2     import sys
        3     from collections import Counter
        4     import numpy as np
        5     import tensorflow.contrib.keras as kr
        6     
        7     if sys.version_info[0] > 2:
        8         is_py3 = True
        9     else:
       10         reload(sys)
       11         sys.setdefaultencoding("utf-8")
       12         is_py3 = False
       13     
       14     def native_word(word, encoding='utf-8'):
       15         """如果在python2下面使用python3训练的模型,可考虑调用此函数转化一下字符编码"""
       16         if not is_py3:
       17             return word.encode(encoding)
       18         else:
       19             return word
       20     
       21     def native_content(content):
       22         if not is_py3:
       23             return content.decode('utf-8')
       24         else:
       25             return content
       26     
       27     def open_file(filename, mode='r'):
       28         """
       29         常用文件操作,可在python2和python3间切换.
       30         mode: 'r' or 'w' for read or write
       31         """
       32         if is_py3:
       33             return open(filename, mode, encoding='utf-8', errors='ignore')
       34         else:
       35             return open(filename, mode)
       36     
       37     def read_file(filename):
       38         """读取文件数据"""
       39         contents, labels = [], []
       40         with open_file(filename) as f:
       41             for line in f:
       42                 try:
       43                     label, content = line.strip().split('	')
       44                     if content:
       45                         contents.append(list(native_content(content)))
       46                         labels.append(native_content(label))
       47                 except:
       48                     pass
       49         return contents, labels
       50     
       51     def build_vocab(train_dir, vocab_dir, vocab_size=5000):
       52         """根据训练集构建词汇表,存储"""
       53         data_train, _ = read_file(train_dir)
       54         all_data = []
       55         for content in data_train:
       56             all_data.extend(content)
       57         counter = Counter(all_data)
       58         count_pairs = counter.most_common(vocab_size - 1)
       59         words, _ = list(zip(*count_pairs))
       60         # 添加一个 <PAD> 来将所有文本pad为同一长度
       61         words = ['<PAD>'] + list(words)
       62         open_file(vocab_dir, mode='w').write('
    '.join(words) + '
    ')
       63     
       64     def read_vocab(vocab_dir):
       65         """读取词汇表"""
       66         # words = open_file(vocab_dir).read().strip().split('
    ')
       67         with open_file(vocab_dir) as fp:
       68             # 如果是py2 则每个值都转化为unicode
       69             words = [native_content(_.strip()) for _ in fp.readlines()]
       70         word_to_id = dict(zip(words, range(len(words))))
       71         return words, word_to_id
       72     
       73     def read_category():
       74         """读取分类目录,固定"""
       75         categories = ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐']
       76         categories = [native_content(x) for x in categories]
       77         cat_to_id = dict(zip(categories, range(len(categories))))
       78         return categories, cat_to_id
       79     
       80     def to_words(content, words):
       81         """将id表示的内容转换为文字"""
       82         return ''.join(words[x] for x in content)
       83     
       84     def process_file(filename, word_to_id, cat_to_id, max_length=600):
       85         """将文件转换为id表示"""
       86         contents, labels = read_file(filename)
       87     
       88         data_id, label_id = [], []
       89         for i in range(len(contents)):
       90             data_id.append([word_to_id[x] for x in contents[i] if x in word_to_id])
       91             label_id.append(cat_to_id[labels[i]])
       92         # 使用keras提供的pad_sequences来将文本pad为固定长度
       93         x_pad = kr.preprocessing.sequence.pad_sequences(data_id, max_length)
       94         y_pad = kr.utils.to_categorical(label_id, num_classes=len(cat_to_id))  # 将标签转换为one-hot表示
       95         return x_pad, y_pad
       96     
       97     def batch_iter(x, y, batch_size=64):
       98         """生成批次数据"""
       99         data_len = len(x)
      100         num_batch = int((data_len - 1) / batch_size) + 1
      101         indices = np.random.permutation(np.arange(data_len))
      102         x_shuffle = x[indices]
      103         y_shuffle = y[indices]
      104     
      105         for i in range(num_batch):
      106             start_id = i * batch_size
      107             end_id = min((i + 1) * batch_size, data_len)
      108             yield x_shuffle[start_id:end_id], y_shuffle[start_id:end_id]

    3.2 模型搭建cnn_model.py

    定义训练的参数,TextCNN()模型

        1     # coding: utf-8
        2     import tensorflow as tf
        3     class TCNNConfig(object):
        4         """CNN配置参数"""
        5         embedding_dim = 64  # 词向量维度
        6         seq_length = 600  # 序列长度
        7         num_classes = 10  # 类别数
        8         num_filters = 256  # 卷积核数目
        9         kernel_size = 5  # 卷积核尺寸
       10         vocab_size = 5000  # 词汇表达小
       11         hidden_dim = 128  # 全连接层神经元
       12         dropout_keep_prob = 0.5  # dropout保留比例
       13         learning_rate = 1e-3  # 学习率
       14         batch_size = 64  # 每批训练大小
       15         num_epochs = 10  # 总迭代轮次
       16         print_per_batch = 100  # 每多少轮输出一次结果
       17         save_per_batch = 10  # 每多少轮存入tensorboard
       18     
       19     class TextCNN(object):
       20         """文本分类,CNN模型"""
       21         def __init__(self, config):
       22             self.config = config
       23             # 三个待输入的数据
       24             self.input_x = tf.placeholder(tf.int32, [None, self.config.seq_length], name='input_x')
       25             self.input_y = tf.placeholder(tf.float32, [None, self.config.num_classes], name='input_y')
       26             self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')
       27             self.cnn()
       28     
       29         def cnn(self):
       30             """CNN模型"""
       31             # 词向量映射
       32             with tf.device('/cpu:0'):
       33                 embedding = tf.get_variable('embedding', [self.config.vocab_size, self.config.embedding_dim])
       34                 embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x)
       35     
       36             with tf.name_scope("cnn"):
       37                 # CNN layer
       38                 conv = tf.layers.conv1d(embedding_inputs, self.config.num_filters, self.config.kernel_size, name='conv')
       39                 # global max pooling layer
       40                 gmp = tf.reduce_max(conv, reduction_indices=[1], name='gmp')
       41     
       42             with tf.name_scope("score"):
       43                 # 全连接层,后面接dropout以及relu激活
       44                 fc = tf.layers.dense(gmp, self.config.hidden_dim, name='fc1')
       45                 fc = tf.contrib.layers.dropout(fc, self.keep_prob)
       46                 fc = tf.nn.relu(fc)
       47                 # 分类器
       48                 self.logits = tf.layers.dense(fc, self.config.num_classes, name='fc2')
       49                 self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1)  # 预测类别
       50     
       51             with tf.name_scope("optimize"):
       52                 # 损失函数,交叉熵
       53                 cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y)
       54                 self.loss = tf.reduce_mean(cross_entropy)
       55                 # 优化器
       56                 self.optim = tf.train.AdamOptimizer(learning_rate=self.config.learning_rate).minimize(self.loss)
       57     
       58             with tf.name_scope("accuracy"):
       59                 # 准确率
       60                 correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls)
       61                 self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    3.3 运行代码run_cnn.py

      1 #!/usr/bin/python
      2 # -*- coding: utf-8 -*-
      3 from __future__ import print_function
      4 import os
      5 import sys
      6 import time
      7 from datetime import timedelta
      8 import numpy as np
      9 import tensorflow as tf
     10 from sklearn import metrics
     11 from cnn_model import TCNNConfig, TextCNN
     12 from  cnews_loader import read_vocab, read_category, batch_iter, process_file, build_vocab
     13 
     14 base_dir = '../data/cnews'
     15 train_dir = os.path.join(base_dir, 'cnews.train.txt')
     16 test_dir = os.path.join(base_dir, 'cnews.test.txt')
     17 val_dir = os.path.join(base_dir, 'cnews.val.txt')
     18 vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')
     19 save_dir = 'checkpoints/textcnn'
     20 save_path = os.path.join(save_dir, 'best_validation')  # 最佳验证结果保存路径
     21 
     22 def get_time_dif(start_time):
     23     """获取已使用时间"""
     24     end_time = time.time()
     25     time_dif = end_time - start_time
     26     return timedelta(seconds=int(round(time_dif)))
     27 
     28 def feed_data(x_batch, y_batch, keep_prob):
     29     feed_dict = {
     30         model.input_x: x_batch,
     31         model.input_y: y_batch,
     32         model.keep_prob: keep_prob
     33     }
     34     return feed_dict
     35 
     36 def evaluate(sess, x_, y_):
     37     """评估在某一数据上的准确率和损失"""
     38     data_len = len(x_)
     39     batch_eval = batch_iter(x_, y_, 128)
     40     total_loss = 0.0
     41     total_acc = 0.0
     42     for x_batch, y_batch in batch_eval:
     43         batch_len = len(x_batch)
     44         feed_dict = feed_data(x_batch, y_batch, 1.0)
     45         loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict)
     46         total_loss += loss * batch_len
     47         total_acc += acc * batch_len
     48     return total_loss / data_len, total_acc / data_len
     49 
     50 def train():
     51     print("Configuring TensorBoard and Saver...")
     52     # 配置 Tensorboard,重新训练时,请将tensorboard文件夹删除,不然图会覆盖
     53     tensorboard_dir = '../tensorboard/textcnn'
     54     if not os.path.exists(tensorboard_dir):
     55         os.makedirs(tensorboard_dir)
     56     tf.summary.scalar("loss", model.loss)
     57     tf.summary.scalar("accuracy", model.acc)
     58     merged_summary = tf.summary.merge_all()
     59     writer = tf.summary.FileWriter(tensorboard_dir)
     60 
     61     # 配置 Saver
     62     saver = tf.train.Saver()
     63     if not os.path.exists(save_dir):
     64         os.makedirs(save_dir)
     65 
     66     print("Loading training and validation data...")
     67     # 载入训练集与验证集
     68     start_time = time.time()
     69     x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, config.seq_length)
     70     x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, config.seq_length)
     71     time_dif = get_time_dif(start_time)
     72     print("Time usage:", time_dif)
     73 
     74     # 创建session
     75     session = tf.Session()
     76     session.run(tf.global_variables_initializer())
     77     writer.add_graph(session.graph)
     78 
     79     print('Training and evaluating...')
     80     start_time = time.time()
     81     total_batch = 0  # 总批次
     82     best_acc_val = 0.0  # 最佳验证集准确率
     83     last_improved = 0  # 记录上一次提升批次
     84     require_improvement = 1000  # 如果超过1000轮未提升,提前结束训练
     85 
     86     flag = False
     87     for epoch in range(config.num_epochs):
     88         print('Epoch:', epoch + 1)
     89         batch_train = batch_iter(x_train, y_train, config.batch_size)
     90         for x_batch, y_batch in batch_train:
     91             feed_dict = feed_data(x_batch, y_batch, config.dropout_keep_prob)
     92             #print("x_batch is {}".format(x_batch.shape))
     93             if total_batch % config.save_per_batch == 0:
     94                 # 每多少轮次将训练结果写入tensorboard scalar
     95                 s = session.run(merged_summary, feed_dict=feed_dict)
     96                 writer.add_summary(s, total_batch)
     97             if total_batch % config.print_per_batch == 0:
     98                 # 每多少轮次输出在训练集和验证集上的性能
     99                 feed_dict[model.keep_prob] = 1.0
    100                 loss_train, acc_train = session.run([model.loss, model.acc], feed_dict=feed_dict)
    101                 loss_val, acc_val = evaluate(session, x_val, y_val)  # todo
    102                 if acc_val > best_acc_val:
    103                     # 保存最好结果
    104                     best_acc_val = acc_val
    105                     last_improved = total_batch
    106                     saver.save(sess=session, save_path=save_path)
    107                     improved_str = '*'
    108                 else:
    109                     improved_str = ''
    110                 time_dif = get_time_dif(start_time)
    111                 msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' 
    112                       + ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}'
    113                 print(msg.format(total_batch, loss_train, acc_train, loss_val, acc_val, time_dif, improved_str))
    114 
    115             session.run(model.optim, feed_dict=feed_dict)  # 运行优化
    116             total_batch += 1
    117 
    118             if total_batch - last_improved > require_improvement:
    119                 # 验证集正确率长期不提升,提前结束训练
    120                 print("No optimization for a long time, auto-stopping...")
    121                 flag = True
    122                 break  # 跳出循环
    123         if flag:  # 同上
    124             break
    125 
    126 def test():
    127     print("Loading test data...")
    128     start_time = time.time()
    129     x_test, y_test = process_file(test_dir, word_to_id, cat_to_id, config.seq_length)
    130 
    131     session = tf.Session()
    132     session.run(tf.global_variables_initializer())
    133     saver = tf.train.Saver()
    134     saver.restore(sess=session, save_path=save_path)  # 读取保存的模型
    135 
    136     print('Testing...')
    137     loss_test, acc_test = evaluate(session, x_test, y_test)
    138     msg = 'Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}'
    139     print(msg.format(loss_test, acc_test))
    140 
    141     batch_size = 128
    142     data_len = len(x_test)
    143     num_batch = int((data_len - 1) / batch_size) + 1
    144 
    145     y_test_cls = np.argmax(y_test, 1)
    146     y_pred_cls = np.zeros(shape=len(x_test), dtype=np.int32)  # 保存预测结果
    147     for i in range(num_batch):  # 逐批次处理
    148         start_id = i * batch_size
    149         end_id = min((i + 1) * batch_size, data_len)
    150         feed_dict = {
    151             model.input_x: x_test[start_id:end_id],
    152             model.keep_prob: 1.0
    153         }
    154         y_pred_cls[start_id:end_id] = session.run(model.y_pred_cls, feed_dict=feed_dict)
    155 
    156     # 评估
    157     print("Precision, Recall and F1-Score...")
    158     print(metrics.classification_report(y_test_cls, y_pred_cls, target_names=categories))
    159 
    160     # 混淆矩阵
    161     print("Confusion Matrix...")
    162     cm = metrics.confusion_matrix(y_test_cls, y_pred_cls)
    163     print(cm)
    164 
    165     time_dif = get_time_dif(start_time)
    166     print("Time usage:" , time_dif)
    167 
    168 if __name__ == '__main__':
    169     
    170     config = TCNNConfig()
    171     if not os.path.exists(vocab_dir):  # 如果不存在词汇表,重建,这里存在,因此不用重建
    172         build_vocab(train_dir, vocab_dir, config.vocab_size)
    173     categories, cat_to_id = read_category()
    174     words, word_to_id = read_vocab(vocab_dir)
    175     config.vocab_size = len(words)
    176     model = TextCNN(config)
    177     option='train'
    178     if option == 'train':
    179         train()
    180     else:
    181         test()

    3.4 预测predict.py

        1     # coding: utf-8
        2     from __future__ import print_function
        3     import os
        4     import tensorflow as tf
        5     import tensorflow.contrib.keras as kr
        6     from cnn_model import TCNNConfig, TextCNN
        7     from cnews_loader import read_category, read_vocab
        8     try:
        9         bool(type(unicode))
       10     except NameError:
       11         unicode = str
       12     
       13     base_dir = '../data/cnews'
       14     vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')
       15     save_dir = '../checkpoints/textcnn'
       16     save_path = os.path.join(save_dir, 'best_validation')  # 最佳验证结果保存路径
       17     
       18     class CnnModel:
       19         def __init__(self):
       20             self.config = TCNNConfig()
       21             self.categories, self.cat_to_id = read_category()
       22             self.words, self.word_to_id = read_vocab(vocab_dir)
       23             self.config.vocab_size = len(self.words)
       24             self.model = TextCNN(self.config)
       25             self.session = tf.Session()
       26             self.session.run(tf.global_variables_initializer())
       27             saver = tf.train.Saver()
       28             saver.restore(sess=self.session, save_path=save_path)  # 读取保存的模型
       29     
       30         def predict(self, message):
       31             # 支持不论在python2还是python3下训练的模型都可以在2或者3的环境下运行
       32             content = unicode(message)
       33             data = [self.word_to_id[x] for x in content if x in self.word_to_id]
       34     
       35             feed_dict = {
       36                 self.model.input_x: kr.preprocessing.sequence.pad_sequences([data], self.config.seq_length),
       37                 self.model.keep_prob: 1.0
       38             }
       39     
       40             y_pred_cls = self.session.run(self.model.y_pred_cls, feed_dict=feed_dict)
       41             return self.categories[y_pred_cls[0]]
       42     
       43     if __name__ == '__main__':
       44         cnn_model = CnnModel()
       45         test_demo = ['三星ST550以全新的拍摄方式超越了以往任何一款数码相机',
       46                      '热火vs骑士前瞻:皇帝回乡二番战 东部次席唾手可得新浪体育讯北京时间3月30日7:00']
       47         for i in test_demo:
       48             print(cnn_model.predict(i))

    四、结果展示

       

       

    相关代码可见:https://github.com/yifanhunter/textClassifier_chinese

       

  • 相关阅读:
    蓝牙的AVDTP协议笔记
    蓝牙的AVCTP协议笔记
    hosts学习整理
    Win10报错0x800f0906
    Git Bash的妙用
    将xml文件由格式化变为压缩字符串
    try-with-resource机制的一个编译陷阱
    Git回滚代码暴力法
    IDEA中Git分支未push的变更集如何合并到另一个分支
    日期类型存储成字符串类型的格式问题
  • 原文地址:https://www.cnblogs.com/yifanrensheng/p/13583443.html
Copyright © 2020-2023  润新知