代码已上传到github:https://github.com/taishan1994/tensorflow-text-classification
往期精彩:
利用TfidfVectorizer进行中文文本分类(数据集是复旦中文语料)
利用transformer进行中文文本分类(数据集是复旦中文语料)
基于tensorflow的中文文本分类
数据集:复旦中文语料,包含20类
数据集下载地址:https://www.kesci.com/mw/dataset/5d3a9c86cf76a600360edd04/content
数据集下载好之后将其放置在data文件夹下;
修改globalConfig.py中的全局路径为自己项目的路径;
处理后的数据和已训练好保存的模型,在这里可以下载:
链接:https://pan.baidu.com/s/1ZHzO5e__-WFYAYFIt2Kmsg 提取码:vvzy
目录结构:
|--checkpint:保存模型目录
|--|--transformer:transformer模型保存位置;
|--config:配置文件;
|--|--fudanConfig.py:包含训练配置、模型配置、数据集配置;
|--|--globaConfig.py:全局配置文件,主要是全局路径、全局参数等;
|-- data:数据保存位置;
|--|--|--Fudan:复旦数据;
|--|--|--train:训练数据;
|--|--|--answer:测试数据;
|--dataset:创建数据集,对数据进行处理的一些操作;
|--images:结果可视化图片保存位置;
|--models:模型保存文件;
|--process:对原始数据进行处理后的数据;
|--tensorboard:tensorboard可视化文件保存位置,暂时未用到;
|--utils:辅助函数保存位置,包括word2vec训练词向量、评价指标计算、结果可视化等;
|--main.py:主运行文件,选择模型、训练、测试和预测;
初始配置:
- 词嵌入维度:200
- 学习率:0.001
- epoch:50
- 词汇表大小:6000+2(加2是PAD和UNK)
- 文本最大长度:600
- 每多少个step进行验证:100
- 每多少个step进行存储模型:100
环境:
- python=>=3.6
- tensorflow==1.15.0
当前支持的模型:
- bilstm
- bilstm+attention
- textcnn
- rcnn
- transformer
说明
数据的输入格式:
(1)分词后去除掉停止词,再对词语进行词频统计,取频数最高的前6000个词语作为词汇表;
(2)像词汇表中加入PAD和UNK,实际上的词汇表的词语总数为6000+2=6002;
(3)当句子长度大于指定的最大长度,进行裁剪,小于最大长度,在句子前面用PAD进行填充;
(4)如果句子中的词语在词汇表中没有出现则用UNK进行代替;
(5)输入到网络中的句子实际上是进行分词后的词语映射的id,比如:
(6)输入的标签是要经过onehot编码的;
"""
"我喜欢上海",
"我喜欢打羽毛球",
"""
词汇表:['我','喜欢','打','上海','羽毛球'],对应映射:[2,3,4,5,6],0对应PAD,1对应UNK
得到:
[
[0,2,3,5],
[2,3,4,6],
]
python main.py --model transformer --saver_dir checkpoint/transformer --save_png images/transformer --train --test --predict
参数说明:
- --model:选择模型,可选[transformer、bilstm、bilstmattn、textcnn、rcnn]
- --saver_dir:模型保存位置,一般是checkpoint+模型名称
- --save_png:结果可视化保存位置,一般是images+模型名称
- --train:是否进行训练,默认为False
- --test:是否进行测试,默认为False
- --predict:是否进行预测,默认为False
结果
以transformer为例:
部分训练结果:
2020-11-01T10:43:16.955322, step: 1300, loss: 5.089711, acc: 0.8546,precision: 0.3990, recall: 0.4061, f_beta: 0.3977 *
Epoch: 83
train: step: 1320, loss: 0.023474, acc: 0.9922, recall: 0.8444, precision: 0.8474, f_beta: 0.8457
Epoch: 84
train: step: 1340, loss: 0.000000, acc: 1.0000, recall: 0.7500, precision: 0.7500, f_beta: 0.7500
Epoch: 85
train: step: 1360, loss: 0.000000, acc: 1.0000, recall: 0.5500, precision: 0.5500, f_beta: 0.5500
Epoch: 86
Epoch: 87
train: step: 1380, loss: 0.000000, acc: 1.0000, recall: 0.7500, precision: 0.7500, f_beta: 0.7500
Epoch: 88
train: step: 1400, loss: 0.000000, acc: 1.0000, recall: 0.7000, precision: 0.7000, f_beta: 0.7000
开始验证。。。
2020-11-01T10:44:07.347359, step: 1400, loss: 5.111372, acc: 0.8506,precision: 0.4032, recall: 0.4083, f_beta: 0.3982 *
Epoch: 89
train: step: 1420, loss: 0.000000, acc: 1.0000, recall: 0.5500, precision: 0.5500, f_beta: 0.5500
Epoch: 90
train: step: 1440, loss: 0.000000, acc: 1.0000, recall: 0.5500, precision: 0.5500, f_beta: 0.5500
Epoch: 91
Epoch: 92
train: step: 1460, loss: 0.000000, acc: 1.0000, recall: 0.7000, precision: 0.7000, f_beta: 0.7000
Epoch: 93
train: step: 1480, loss: 0.000000, acc: 1.0000, recall: 0.7500, precision: 0.7500, f_beta: 0.7500
Epoch: 94
train: step: 1500, loss: 0.000000, acc: 1.0000, recall: 0.6000, precision: 0.6000, f_beta: 0.6000
开始验证。。。
2020-11-01T10:44:57.645305, step: 1500, loss: 5.206666, acc: 0.8521,precision: 0.4003, recall: 0.4040, f_beta: 0.3957
Epoch: 95
train: step: 1520, loss: 0.000000, acc: 1.0000, recall: 0.6000, precision: 0.6000, f_beta: 0.6000
Epoch: 96
Epoch: 97
train: step: 1540, loss: 0.000000, acc: 1.0000, recall: 0.7500, precision: