• NLP(二十九):BertForSequenceClassification的新闻标题分类,基于pytorch_pretrained_bert


    背景

    BERT的问世向世人宣告了无监督预训练的语言模型在众多NLP任务中成为“巨人肩膀”的可能性,接踵而出的GPT2、XL-Net则不断将NLP从业者的期望带向了新的高度。得益于这些力作模型的开源,使得我们在了解其论文思想的基础上,可以借力其凭借强大算力预训练的模型从而快速在自己的数据集上开展实验,甚至应用于真实的业务中。

    在GitHub上已经存在使用多种语言/框架依照Google最初release的TensorFlow版本的代码进行实现的Pretrained-BERT,并且都提供了较为详细的文档。本文主要展示通过极简的代码调用Pytorch Pretrained-BERT并进行fine-tuning的文本分类任务。

    下面的代码是使用pytorch-pretrained-BERT进行文本分类的官方实现,感兴趣的同学可以直接点进去阅读:

    https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/examples/run_classifier.py​github.com

    数据介绍

    本文所使用的数据是标题及其对应的类别,如“中国的垃圾分类能走多远”对应“社会”类别,共有28个类别,每个类别的训练数据和测试数据各有1000条,数据已经同步至云盘,欢迎下载。链接:

    https://pan.baidu.com/s/1r4SI6-IizlCcsyMGL7RU8Q​pan.baidu.com

    提取码: 6awx

    加载库

    import os
    import sys
    import pickle
    import pandas as pd
    import numpy as np
    from concurrent.futures import ThreadPoolExecutor
    import torch
    import pickle
    from sklearn.preprocessing import LabelEncoder
    from torch.optim import optimizer
    from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
    from torch.nn import CrossEntropyLoss,BCEWithLogitsLoss
    from tqdm import tqdm_notebook, trange
    from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM, BertForSequenceClassification
    from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
    from sklearn.metrics import precision_recall_curve,classification_report
    import matplotlib.pyplot as plt
    %matplotlib inline
    

    加载数据

    # pandas读取数据
    data = pd.read_pickle("title_category.pkl")
    # 列名重新命名
    data.columns = ['text','label']

    标签编码

    因为label为中文格式,为了适应模型的输入需要进行ID化,此处调用sklearn中的label encoder方法快速进行变换。

    le = LabelEncoder()
    le.fit(data.label.tolist())
    data['label'] = le.transform(data.label.tolist())

    观察数据

    训练数据准备

    本文需要使用的预训练bert模型为使用中文维基语料训练的字符级别的模型,在Google提供的模型列表中对应的名称为'bert-base-chinese',使用更多语言语料训练的模型名称可以参见下方链接:

    另外,首次执行下面的代码时因为本地没有cache,因此会自动启动下载,实践证明下载速度还是很快的。需要注意的是,do_lower_case参数需要手动显式的设置为False

    # 分词工具
    bert_tokenizer = BertTokenizer.from_pretrained('bert-base-chinese', do_lower_case=False)
    # 封装类
    class DataPrecessForSingleSentence(object):
        """
        对文本进行处理
        """
    def __init__(self, bert_tokenizer, max_workers=10):
            """
            bert_tokenizer :分词器
            dataset        :包含列名为'text'与'label'的pandas dataframe
            """
            self.bert_tokenizer = bert_tokenizer
            # 创建多线程池
            self.pool = ThreadPoolExecutor(max_workers=max_workers)
            # 获取文本与标签
    def get_input(self, dataset, max_seq_len=30):
            """
            通过多线程(因为notebook中多进程使用存在一些问题)的方式对输入文本进行分词、ID化、截断、填充等流程得到最终的可用于模型输入的序列。
            
            入参:
                dataset     : pandas的dataframe格式,包含两列,第一列为文本,第二列为标签。标签取值为{0,1},其中0表示负样本,1代表正样本。
                max_seq_len : 目标序列长度,该值需要预先对文本长度进行分别得到,可以设置为小于等于512(BERT的最长文本序列长度为512)的整数。
            
            出参:
                seq         : 在入参seq的头尾分别拼接了'CLS'与'SEP'符号,如果长度仍小于max_seq_len,则使用0在尾部进行了填充。
                seq_mask    : 只包含0、1且长度等于seq的序列,用于表征seq中的符号是否是有意义的,如果seq序列对应位上为填充符号,
                              那么取值为1,否则为0。
                seq_segment : shape等于seq,因为是单句,所以取值都为0。
                labels      : 标签取值为{0,1},其中0表示负样本,1代表正样本。
            
                
            """
            sentences = dataset.iloc[:, 0].tolist()
            labels = dataset.iloc[:, 1].tolist()
            # 切词
            tokens_seq = list(
                self.pool.map(self.bert_tokenizer.tokenize, sentences))
            # 获取定长序列及其mask
            result = list(
                self.pool.map(self.trunate_and_pad, tokens_seq,
                              [max_seq_len] * len(tokens_seq)))
            seqs = [i[0] for i in result]
            seq_masks = [i[1] for i in result]
            seq_segments = [i[2] for i in result]
            return seqs, seq_masks, seq_segments, labels
    def trunate_and_pad(self, seq, max_seq_len):
            """
            1. 因为本类处理的是单句序列,按照BERT中的序列处理方式,需要在输入序列头尾分别拼接特殊字符'CLS'与'SEP',
               因此不包含两个特殊字符的序列长度应该小于等于max_seq_len-2,如果序列长度大于该值需要那么进行截断。
            2. 对输入的序列 最终形成['CLS',seq,'SEP']的序列,该序列的长度如果小于max_seq_len,那么使用0进行填充。
            
            入参: 
                seq         : 输入序列,在本处其为单个句子。
                max_seq_len : 拼接'CLS'与'SEP'这两个特殊字符后的序列长度
            
            出参:
                seq         : 在入参seq的头尾分别拼接了'CLS'与'SEP'符号,如果长度仍小于max_seq_len,则使用0在尾部进行了填充。
                seq_mask    : 只包含0、1且长度等于seq的序列,用于表征seq中的符号是否是有意义的,如果seq序列对应位上为填充符号,
                              那么取值为1,否则为0。
                seq_segment : shape等于seq,因为是单句,所以取值都为0。
               
            """
            # 对超长序列进行截断
            if len(seq) > (max_seq_len - 2):
                seq = seq[0:(max_seq_len - 2)]
            # 分别在首尾拼接特殊符号
            seq = ['[CLS]'] + seq + ['[SEP]']
            # ID化
            seq = self.bert_tokenizer.convert_tokens_to_ids(seq)
            # 根据max_seq_len与seq的长度产生填充序列
            padding = [0] * (max_seq_len - len(seq))
            # 创建seq_mask
            seq_mask = [1] * len(seq) + padding
            # 创建seq_segment
            seq_segment = [0] * len(seq) + padding
            # 对seq拼接填充序列
            seq += padding
            assert len(seq) == max_seq_len
            assert len(seq_mask) == max_seq_len
            assert len(seq_segment) == max_seq_len
            return seq, seq_mask, seq_segment
    

    DataPrecessForSingleSentence是一个用于将pandas Dataframe转化为模型输入的类,每个函数的入参和出参已经写得比较清晰翔实了。处理流程大致如下:

    • 通过多线程的方式进行调用tokenize进行切词(字符级别)
    • 对于切词产生的序列如果长度大于设置的max_seq_len-2时需要进行截断。BERT中使用的max_seq_len是512,因此最长不可以超过512个字符。另外,本处需要减2的原因在于还需要在原始序列上拼接两个特殊符号,因此需要预留两个字符的“槽位”。
    • 在首、尾分别拼接'[CLS]'及'[SEP]',如果序列长度不足max_seq_len,使用0进行填充。产生相应的mask序列和segment序列,其中mask序列使用0、1值标注对应位上是否为填充符号,如果是那么取值为0,负责为1,如果序列长度不足max_seq_len,使用0进行填充。segment序列则用于表示序列是否为同一个输入源,在本例中取值全部为0,如果序列长度不足max_seq_len,使用0进行填充。
    • 对于填充后的序列进行ID化,调用的是convert_tokens_to_ids方法,最终返回seq,seq_mask 与seq_segment序列。
    # 类初始化
    processor = DataPrecessForSingleSentence(bert_tokenizer= bert_tokenizer)
    # 产生输入ju 数据
    seqs, seq_masks, seq_segments, labels = processor.get_input(
        dataset=data, max_seq_len=30)

    本文设定的max_seq_len为30,因为通过统计标题的长度可以得知30已经是其85百分位数,基本已经涵盖了绝大部分样本。

    加载预训练的bert模型

    # 加载预训练的bert模型
    model = BertForSequenceClassification.from_pretrained(
        'bert-base-chinese', num_labels=28)

    同样,首次执行会自动启动下载,在本例中因为有28个类别,因此num_labels参数需要设置为28。

    数据格式化

    数据格式化指的是将list格式的数据转化为torch的tensor格式。

    # 转换为torch tensor
    t_seqs = torch.tensor(seqs, dtype=torch.long)
    t_seq_masks = torch.tensor(seq_masks, dtype = torch.long)
    t_seq_segments = torch.tensor(seq_segments, dtype = torch.long)
    t_labels = torch.tensor(labels, dtype = torch.long)
    train_data = TensorDataset(t_seqs, t_seq_masks, t_seq_segments, t_labels)
    train_sampler = RandomSampler(train_data)
    train_dataloder = DataLoader(dataset= train_data, sampler= train_sampler,batch_size = 256)
    

    使用了TensorDatasetRandomSamplerDataLoader对输入数据进行了封装,相较于自己编写generator代码量简短很多,此处设置的batch size为256。

    # 将模型转换为trin mode
    model.train()
    
    BertForSequenceClassification(
      (bert): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(21128, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): BertLayerNorm()
          (dropout): Dropout(p=0.1)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1)
                )
                (output): BertSelfOutput(
                  (dense): Linear(in_features=768, out_features=768, bias=True)
                  (LayerNorm): BertLayerNorm()
                  (dropout): Dropout(p=0.1)
                )
              )
              (intermediate): BertIntermediate(
                (dense): Linear(in_features=768, out_features=3072, bias=True)
              )
              (output): BertOutput(
                (dense): Linear(in_features=3072, out_features=768, bias=True)
                (LayerNorm): BertLayerNorm()
                (dropout): Dropout(p=0.1)
              )
            )
            (1): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1)
                )
                (output): BertSelfOutput(
                  (dense): Linear(in_features=768, out_features=768, bias=True)
                  (LayerNorm): BertLayerNorm()
                  (dropout): Dropout(p=0.1)
                )
              )
              (intermediate): BertIntermediate(
                (dense): Linear(in_features=768, out_features=3072, bias=True)
              )
              (output): BertOutput(
                (dense): Linear(in_features=3072, out_features=768, bias=True)
                (LayerNorm): BertLayerNorm()
                (dropout): Dropout(p=0.1)
              )
            )
            (2): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1)
                )
                (output): BertSelfOutput(
                  (dense): Linear(in_features=768, out_features=768, bias=True)
                  (LayerNorm): BertLayerNorm()
                  (dropout): Dropout(p=0.1)
                )
              )
              (intermediate): BertIntermediate(
                (dense): Linear(in_features=768, out_features=3072, bias=True)
              )
              (output): BertOutput(
                (dense): Linear(in_features=3072, out_features=768, bias=True)
                (LayerNorm): BertLayerNorm()
                (dropout): Dropout(p=0.1)
              )
            )
            (3): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1)
                )
                (output): BertSelfOutput(
                  (dense): Linear(in_features=768, out_features=768, bias=True)
                  (LayerNorm): BertLayerNorm()
                  (dropout): Dropout(p=0.1)
                )
              )
              (intermediate): BertIntermediate(
                (dense): Linear(in_features=768, out_features=3072, bias=True)
              )
              (output): BertOutput(
                (dense): Linear(in_features=3072, out_features=768, bias=True)
                (LayerNorm): BertLayerNorm()
                (dropout): Dropout(p=0.1)
              )
            )
            (4): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1)
                )
                (output): BertSelfOutput(
                  (dense): Linear(in_features=768, out_features=768, bias=True)
                  (LayerNorm): BertLayerNorm()
                  (dropout): Dropout(p=0.1)
                )
              )
              (intermediate): BertIntermediate(
                (dense): Linear(in_features=768, out_features=3072, bias=True)
              )
              (output): BertOutput(
                (dense): Linear(in_features=3072, out_features=768, bias=True)
                (LayerNorm): BertLayerNorm()
                (dropout): Dropout(p=0.1)
              )
            )
            (5): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1)
                )
                (output): BertSelfOutput(
                  (dense): Linear(in_features=768, out_features=768, bias=True)
                  (LayerNorm): BertLayerNorm()
                  (dropout): Dropout(p=0.1)
                )
              )
              (intermediate): BertIntermediate(
                (dense): Linear(in_features=768, out_features=3072, bias=True)
              )
              (output): BertOutput(
                (dense): Linear(in_features=3072, out_features=768, bias=True)
                (LayerNorm): BertLayerNorm()
                (dropout): Dropout(p=0.1)
              )
            )
            (6): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1)
                )
                (output): BertSelfOutput(
                  (dense): Linear(in_features=768, out_features=768, bias=True)
                  (LayerNorm): BertLayerNorm()
                  (dropout): Dropout(p=0.1)
                )
              )
              (intermediate): BertIntermediate(
                (dense): Linear(in_features=768, out_features=3072, bias=True)
              )
              (output): BertOutput(
                (dense): Linear(in_features=3072, out_features=768, bias=True)
                (LayerNorm): BertLayerNorm()
                (dropout): Dropout(p=0.1)
              )
            )
            (7): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1)
                )
                (output): BertSelfOutput(
                  (dense): Linear(in_features=768, out_features=768, bias=True)
                  (LayerNorm): BertLayerNorm()
                  (dropout): Dropout(p=0.1)
                )
              )
              (intermediate): BertIntermediate(
                (dense): Linear(in_features=768, out_features=3072, bias=True)
              )
              (output): BertOutput(
                (dense): Linear(in_features=3072, out_features=768, bias=True)
                (LayerNorm): BertLayerNorm()
                (dropout): Dropout(p=0.1)
              )
            )
            (8): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1)
                )
                (output): BertSelfOutput(
                  (dense): Linear(in_features=768, out_features=768, bias=True)
                  (LayerNorm): BertLayerNorm()
                  (dropout): Dropout(p=0.1)
                )
              )
              (intermediate): BertIntermediate(
                (dense): Linear(in_features=768, out_features=3072, bias=True)
              )
              (output): BertOutput(
                (dense): Linear(in_features=3072, out_features=768, bias=True)
                (LayerNorm): BertLayerNorm()
                (dropout): Dropout(p=0.1)
              )
            )
            (9): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1)
                )
                (output): BertSelfOutput(
                  (dense): Linear(in_features=768, out_features=768, bias=True)
                  (LayerNorm): BertLayerNorm()
                  (dropout): Dropout(p=0.1)
                )
              )
              (intermediate): BertIntermediate(
                (dense): Linear(in_features=768, out_features=3072, bias=True)
              )
              (output): BertOutput(
                (dense): Linear(in_features=3072, out_features=768, bias=True)
                (LayerNorm): BertLayerNorm()
                (dropout): Dropout(p=0.1)
              )
            )
            (10): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1)
                )
                (output): BertSelfOutput(
                  (dense): Linear(in_features=768, out_features=768, bias=True)
                  (LayerNorm): BertLayerNorm()
                  (dropout): Dropout(p=0.1)
                )
              )
              (intermediate): BertIntermediate(
                (dense): Linear(in_features=768, out_features=3072, bias=True)
              )
              (output): BertOutput(
                (dense): Linear(in_features=3072, out_features=768, bias=True)
                (LayerNorm): BertLayerNorm()
                (dropout): Dropout(p=0.1)
              )
            )
            (11): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1)
                )
                (output): BertSelfOutput(
                  (dense): Linear(in_features=768, out_features=768, bias=True)
                  (LayerNorm): BertLayerNorm()
                  (dropout): Dropout(p=0.1)
                )
              )
              (intermediate): BertIntermediate(
                (dense): Linear(in_features=768, out_features=3072, bias=True)
              )
              (output): BertOutput(
                (dense): Linear(in_features=3072, out_features=768, bias=True)
                (LayerNorm): BertLayerNorm()
                (dropout): Dropout(p=0.1)
              )
            )
          )
        )
        (pooler): BertPooler(
          (dense): Linear(in_features=768, out_features=768, bias=True)
          (activation): Tanh()
        )
      )
      (dropout): Dropout(p=0.1)
      (classifier): Linear(in_features=768, out_features=28, bias=True)
    )
    

    从打印出的网络结构可以看出,classifier层的out_features已经设置为了上文的提到的28。另外,我们可以关注一下BertPooler层,如果对于前面步骤中在序列头部拼接[CLS]有疑问的话,通过阅读BertPooler的代码可以明晰该字符的用处。

    # link : https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/modeling.py
    class BertPooler(nn.Module):
        def __init__(self, config):
            super(BertPooler, self).__init__()
            self.dense = nn.Linear(config.hidden_size, config.hidden_size)
            self.activation = nn.Tanh()
    def forward(self, hidden_states):
            # We "pool" the model by simply taking the hidden state corresponding
            # to the first token.
            first_token_tensor = hidden_states[:, 0]
            pooled_output = self.dense(first_token_tensor)
            pooled_output = self.activation(pooled_output)
            return pooled_output
    

    上面的代码是BertPooler的实现,可以看出在forward方法中hidden_states[:, 0]只取了第一个字符对应的hidden unit,因此凭借双向Encoder的表征能力,'[CLS]'符号融合了整个序列的表征信息,因此可以用于以一种低维的方式对整个序列进行表征。

    # 待优化的参数
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            'params':
            [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
            'weight_decay':
            0.01
        },
        {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }
    ]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=2e-05,
                         warmup= 0.1 ,
                         t_total= 2000)
    device = 'cpu'
    

    我记得当时在看《动手学深度学习》一书(3.12节)时,李沐提到权重衰减等价于L2正则化。在bert官方的代码中对于bias项、LayerNorm.biasLayerNorm.weight项免于正则化。

    fine-tuning

    # 存储每一个batch的loss
    loss_collect = []
    for i in trange(10, desc='Epoch'):
        for step, batch_data in enumerate(
                tqdm_notebook(train_dataloder, desc='Iteration')):
            batch_data = tuple(t.to(device) for t in batch_data)
            batch_seqs, batch_seq_masks, batch_seq_segments, batch_labels = batch_data
            # 对标签进行onehot编码
            one_hot = torch.zeros(batch_labels.size(0), 28).long()
            one_hot_batch_labels = one_hot.scatter_(
                dim=1,
                index=torch.unsqueeze(batch_labels, dim=1),
                src=torch.ones(batch_labels.size(0), 28).long())
    logits = model(
                batch_seqs, batch_seq_masks, batch_seq_segments, labels=None)
            logits = logits.softmax(dim=1)
            loss_function = CrossEntropyLoss()
            loss = loss_function(logits, batch_labels)
            loss.backward()
            loss_collect.append(loss.item())
            print("
    %f" % loss, end='')
            optimizer.step()
            optimizer.zero_grad()
    

    总共进行了10个epoch的训练,将各个batch的loss写入了loss_collect,下面对loss_collect进行可视化。

    loss可视化

    plt.figure(figsize=(12,8))
    plt.plot(range(len(loss_collect)), loss_collect,'g.')
    plt.grid(True)
    plt.show()
    

     

    从上图可以看出,loss在前200个batch下降速度明显,随后下降速度逐渐变缓,但从整体趋势以及纵轴的loss绝对值可以看出,loss距离收敛还存在一定空间,如果增大训练样本量及迭代次数,loss依然可以继续减小。

    测试

    模型持久化

    torch.save(model,open("fine_tuned_chinese_bert.bin","wb"))

    加载测试数据

    test_data = pd.read_pickle("title_category_valid.pkl")
    test_data.columns = ['text','label']
    # 标签ID化
    test_data['label'] = le.transform(test_data.label.tolist())
    # 转换为tensor
    test_seqs, test_seq_masks, test_seq_segments, test_labels = processor.get_input(
        dataset=test_data, max_seq_len=30)
    test_seqs = torch.tensor(test_seqs, dtype=torch.long)
    test_seq_masks = torch.tensor(test_seq_masks, dtype = torch.long)
    test_seq_segments = torch.tensor(test_seq_segments, dtype = torch.long)
    test_labels = torch.tensor(test_labels, dtype = torch.long)
    test_data = TensorDataset(test_seqs, test_seq_masks, test_seq_segments, test_labels)
    test_dataloder = DataLoader(dataset= train_data, batch_size = 256)
    # 用于存储预测标签与真实标签
    true_labels = []
    pred_labels = []
    model.eval()
    # 预测
    with torch.no_grad():
        for batch_data in tqdm_notebook(test_dataloder, desc = 'TEST'):
            batch_data = tuple(t.to(device) for t in batch_data)
            batch_seqs, batch_seq_masks, batch_seq_segments, batch_labels = batch_data        
            logits = model(
                batch_seqs, batch_seq_masks, batch_seq_segments, labels=None)
            logits = logits.softmax(dim=1).argmax(dim = 1)
            pred_labels.append(logits.detach().numpy())
            true_labels.append(batch_labels.detach().numpy())
    # 查看各个类别的准召
    print(classification_report(np.concatenate(true_labels), np.concatenate(pred_labels)))       
    

     

                  precision    recall  f1-score   support
    0       0.93      0.95      0.94      1000
               1       0.88      0.90      0.89      1000
               2       0.91      0.92      0.91      1000
               3       0.88      0.95      0.92      1000
               4       0.88      0.92      0.90      1000
               5       0.91      0.91      0.91      1000
               6       0.85      0.84      0.84      1000
               7       0.93      0.97      0.95      1000
               8       0.88      0.94      0.91      1000
               9       0.77      0.86      0.81      1000
              10       0.97      0.94      0.96      1000
              11       0.85      0.90      0.88      1000
              12       0.91      0.97      0.94      1000
              13       0.75      0.86      0.80      1000
              14       0.84      0.90      0.87      1000
              15       0.77      0.87      0.82      1000
              16       0.91      0.95      0.93      1000
              17       0.96      0.95      0.95      1000
              18       0.91      0.93      0.92      1000
              19       0.92      0.94      0.93      1000
              20       0.94      0.93      0.93      1000
              21       0.80      0.80      0.80      1000
              22       0.93      0.97      0.95      1000
              23       0.82      0.86      0.84      1000
              24       0.00      0.00      0.00      1000
              25       0.92      0.93      0.93      1000
              26       0.89      0.90      0.89      1000
              27       0.89      0.89      0.89      1000
    micro avg       0.88      0.88      0.88     28000
       macro avg       0.85      0.88      0.86     28000
    weighted avg       0.85      0.88      0.86     28000
    

    可以看出,整体的准召还是比较理想的,不过因为训练和测试都是使用的平衡数据集,因此在真实分布上的准召与该数据集存在一定差异。

    总结

    本文主要是对run_classifier.py的代码进行了简化,然后在中文数据集上进行了fine-tuning。具体的数据集和代码在文中进行了提供和展示,欢迎交流!

    转自:https://zhuanlan.zhihu.com/p/72448986

  • 相关阅读:
    xiong_6博客迁址
    调用百度地图获取地理位置
    fastadmin 接口开发注意事项
    fastadmin下拉选择框数据生成
    fastadmin 模型篇
    fastadmin跨数据库配置模型
    Git学习笔记#2-创建版本库与提交文件
    Git学习笔记#1-基本概念
    科目一知识点汇总
    Mysql学习笔记#6-约束
  • 原文地址:https://www.cnblogs.com/zhangxianrong/p/15067047.html
Copyright © 2020-2023  润新知