• Bert源码解读(三)之预训练部分


    一、Masked LM

    get_masked_lm_output函数用于计算「任务#1」的训练 loss。输入为 BertModel 的最后一层 sequence_output 输出([batch_size, seq_length, hidden_size]),先找出输出结果中masked掉的词,然后构建一层全连接网络,接着构建一层节点数为vocab_size的softmax输出,从而与真实label计算损失。

    def get_masked_lm_output(bert_config, 
                            input_tensor, #BertModel的最后一层sequence_output输出model.get_sequence_output()[batch_size, seq_length, hidden_size]
                            output_weights,#输入是model.get_embedding_table(),[vocab_size,hidden_size]
                               positions, #mask词的位置
                             label_ids, #label,真实值结果
                             label_weights):
                             
      """Get loss and log probs for the masked LM."""
      # 根据positions位置获取masked词在Transformer的输出结果,即要预测的那些位置的encoder
      input_tensor = gather_indexes(input_tensor, positions)#[batch_size*max_pred_pre_seq,hidden_size]
    
      with tf.variable_scope("cls/predictions"):
        # 在输出之前添加一个带激活函数的全连接神经网络,只在预训练阶段起作用
        with tf.variable_scope("transform"):
          input_tensor = tf.layers.dense(
              input_tensor,
              units=bert_config.hidden_size,
              activation=modeling.get_activation(bert_config.hidden_act),
              kernel_initializer=modeling.create_initializer(
                  bert_config.initializer_range))
          input_tensor = modeling.layer_norm(input_tensor)
    
        # output_weights是和传入的word embedding一样的,这里再添加一个bias
        output_bias = tf.get_variable(
            "output_bias",
            shape=[bert_config.vocab_size],
            initializer=tf.zeros_initializer())
            
        logits = tf.matmul(input_tensor, output_weights, transpose_b=True) #[batch_size*max_pred_pre_seq,vocab_size]
        logits = tf.nn.bias_add(logits, output_bias)
        log_probs = tf.nn.log_softmax(logits, axis=-1)#得出masked词的softmax结果,[batch_size*max_pred_pre_seq,vocab_size]
    
        # label_ids表示mask掉的Token的id,下面这部分就是根据真实值计算loss了。
        label_ids = tf.reshape(label_ids, [-1])#[batch_size*max_pred_per_seq] 
        label_weights = tf.reshape(label_weights, [-1])
    
        one_hot_labels = tf.one_hot(
            label_ids, depth=bert_config.vocab_size, dtype=tf.float32)#[batch_size*max_pred_per_seq,vocab_size]
    
        # 但是由于实际MASK的可能不到20,比如只MASK18,那么label_ids有2个0(padding),而label_weights=[1, 1, ...., 0, 0],说明后面两个label_id是padding的,计算loss要去掉,label_weights就是起一个标记作用
        per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])#[batch_size*max_pred_per_seq] 
        numerator = tf.reduce_sum(label_weights * per_example_loss) #一个batch的loss 
        denominator = tf.reduce_sum(label_weights) + 1e-5
        loss = numerator / denominator  #平均loss 
    
      return (loss, per_example_loss, log_probs)

    重要补充:预训练中的随机MASK函数

      核心思想:每个输入序列,只有最多15%的token被mask,而其中80%的机会被替换成[MASK],10%的机会保持原词不变,10%的机会随机替换为字典中的任意词。代码如何实现呢?先获取每个token的索引位置,然后随机打乱索引位置,接着取前15%的token进行替换即可。在替换中,再次利用随机函数,实现80%替换为[MASK]等,代码层面利用random函数还是比较巧妙的。

    def create_masked_lm_predictions(tokens, #list存放的sequence,例如[CLS,今, 天, 举, 行, 的, 国, 家, 发, 展, 改, 革, 委, 新, 闻, 发, 布, 会, SEP]
                                     masked_lm_prob, #代码中是0.15
                                     max_predictions_per_seq, #代码中20
                                     vocab_words, 
                                     rng): #rng=random.Random()
    
      cand_indexes = []
      # [CLS]和[SEP]不能用于MASK
      for (i, token) in enumerate(tokens):
        if token == "[CLS]" or token == "[SEP]":
          continue
        cand_indexes.append(i)
        
      #随机打乱索引顺序
      rng.shuffle(cand_indexes)
    
      output_tokens = list(tokens)
      #masked token数量,从最大mask配置数和seq长度*mask比例中取一个最小数,作为这个seq最终的mask数量
      num_to_predict = min(max_predictions_per_seq,
                           max(1, int(round(len(tokens) * masked_lm_prob))))
      
      masked_lms = []
      #covered_indexes存放被mask token的索引位置
      covered_indexes = set()
      for index in cand_indexes:
         #达到mask的数量,就停止
        if len(masked_lms) >= num_to_predict:
          break
        if index in covered_indexes:
          continue
        covered_indexes.add(index)
    
        masked_token = None
        # 80% of the time, replace with [MASK],替换为[MASK]
        if rng.random() < 0.8:
          masked_token = "[MASK]"
        else:
          # 10% of the time, keep original,保持原词
          if rng.random() < 0.5:
            masked_token = tokens[index]
          # 10% of the time, replace with random word,随机替换
          else:
            masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
            
        #将masked_token替换覆盖原token
        output_tokens[index] = masked_token
        
        #保存masked token的原索引位置,及真实的label token
        masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
    
      # 按照下标重排,保证是原来句子中出现的顺序
      masked_lms = sorted(masked_lms, key=lambda x: x.index)
    
      masked_lm_positions = []
      masked_lm_labels = []
      for p in masked_lms:
        masked_lm_positions.append(p.index)
        masked_lm_labels.append(p.label)
        
      #返回带mask的sequence tokens,被masked token的原索引位置,及原来的真实label token ,以便计算loss
      return (output_tokens, masked_lm_positions, masked_lm_labels)

    举例实现随机替换的思想:

    import random,collections
    MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                              ["index", "label"])
    #返回的是一个Random对象,每次再调用rng.random()都返回一个0~1的随机数,这里与bert原代码保持一致,种子都是12345
    rng=random.Random(12345)#这里rng一定要放在函数外面,这样相当于在外部完成初始化,每次调用函数才会随机生成不断变化的结果
    def create_mask_sample(sequence="",mask_prob=0.15,vocab_words=[],rng=None):
        tokens=[]
        cand_indexes = []
        for i,w in enumerate(sequence):
            cand_indexes.append(i)
            tokens.append(w)
            
        #随机打乱索引顺序
        rng.shuffle(cand_indexes)
        #mask后输出tokens
        output_tokens = list(tokens)
        #一个输入序列中需要mask的数量
        num_to_predict = int(len(tokens)*mask_prob)
      
        masked_lms = []
        #covered_indexes存放被mask token的索引位置
        covered_indexes = set()
        for index in cand_indexes:
         #达到mask的数量,就停止
            if len(masked_lms) >= num_to_predict:
                break
            if index in covered_indexes:
                continue
            covered_indexes.add(index)
    
            masked_token = None
            # 80% of the time, replace with [MASK],替换为[MASK]
            if rng.random() < 0.8:     #这里有80%的概率是满足<0.8
                masked_token = "[MASK]"
            else:                    #如果是>=0.8情况呢,这里有20%的概率
              # 剩下的概率一半保持原词,也就是10% of the time, keep original,保持原词
              if rng.random() < 0.5:
                masked_token = tokens[index]
              # 10% of the time, replace with random word,随机替换
              else:
                masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
            
            #将masked_token替换覆盖原token
            output_tokens[index] = masked_token
        
            #保存masked token的原索引位置,及真实的label token
            masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
    
        # 按照下标重排,保证是原来句子中出现的顺序
        masked_lms = sorted(masked_lms, key=lambda x: x.index)
    
        masked_lm_positions = []
        masked_lm_labels = []
        for p in masked_lms:
            masked_lm_positions.append(p.index)
            masked_lm_labels.append(p.label)
        
          #返回带mask的sequence tokens,被masked token的原索引位置,及原来的真实label token ,以便计算loss
        return (output_tokens, masked_lm_positions, masked_lm_labels)
    #举例子测试
    seq='今天下午举行的市新冠肺炎疫情防控工作领导小组新闻发布会透露:近期,多个国家和地区出现新冠肺炎确诊病例,数量持续攀升。鉴于当前境外疫情防控形势,结合上海实际,市防控工作领导小组及相关部门综合研判,进一步明确了涉外疫情防控和入境人员健康管理措施。'
    v_words=['', '', '', '3', '', '1', '', '', '', '', '', '', '',
     '', '', '', '', '', '', '', '', '', '', '', '', '',
     '', '', '', '', '', '', '', '', '', '', '', '', '',
     '', '', '', '', '', '', '', '', '', '', '8', '', '',
     '', '', '', '', '', '', '', '', '', '', '', '2', '4', '',
     '', '', '', '', '', '', '', '', '', '', '', '', '', '',
     '', '', '', '', '', '', '', '', '', '', '', '', '', '']
    output_tokens,masked_lm_positions,masked_lm_labels=create_mask_sample(sequence=seq,mask_prob=0.1,vocab_words=v_words,rng=rng)
    print(len(output_tokens))
    print(''.join(output_tokens))
    print(masked_lm_positions)
    print(masked_lm_labels)

    out:
    121
    今天下午[MASK]行的市新冠肺炎疫情防控工[MASK]领导小组新闻发布会透露:[MASK]期,多个国家和地区出现新冠[MASK]炎确诊病例[MASK]数量持[MASK]攀升[MASK]鉴于当[MASK]境外疫情防控形势,结合上海实际,市防控工[MASK]领导小组及相关部门[MASK]合研判,进一步明[MASK]了涉外疫情防控和入境人员[MASK]康管理措施。
    [4, 17, 30, 44, 50, 54, 57, 61, 82, 92, 101, 114]
    ['举', '作', '近', '肺', ',', '续', '。', '前', '作', '综', '确', '健']
    

     注意:同一段话,每调用一次都会随机生成不同的mask结果,达到随机mask目的。

    二、 Next Sentence Prediction

    get_next_sentence_output函数用于计算「任务#2」的训练 loss,这部分比较简单,只需要再额外加一层softmax输出即可。输入为 BertModel 的最后一层 pooled_output 输出([batch_size, hidden_size]),因为该任务属于二分类问题,所以只需要每个序列的第一个 token【CLS】即可。

    def get_next_sentence_output(bert_config,
                                input_tensor,#pooled_output 输出,shape=[batch_size, hidden_size]
                                labels):
      """Get loss and log probs for the next sentence prediction."""
    
     # 标签0表示 下一个句子关系成立;标签1表示 下一个句子关系不成立。这个分类器的参数在实际Fine-tuning阶段会丢弃掉
      with tf.variable_scope("cls/seq_relationship"):
      #初始化权重参数,最终的分类结果是只有2个,所以shape=[2,hidden_size]
        output_weights = tf.get_variable(
            "output_weights",
            shape=[2, bert_config.hidden_size],
            initializer=modeling.create_initializer(bert_config.initializer_range))
        output_bias = tf.get_variable(
            "output_bias", shape=[2], initializer=tf.zeros_initializer())
        
        logits = tf.matmul(input_tensor, output_weights, transpose_b=True)#输入与权重相乘,shape=[batch_size,2]
        logits = tf.nn.bias_add(logits, output_bias)
        log_probs = tf.nn.log_softmax(logits, axis=-1)#softmax输出:shape=[batch_size,2]
        
        #下面这部分就是根据真实值计算损失loss了
        labels = tf.reshape(labels, [-1])
        one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
        per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
        loss = tf.reduce_mean(per_example_loss)
        return (loss, per_example_loss, log_probs)
  • 相关阅读:
    方法的重载
    this用法
    简单的随机数 代码和笔记
    java内存简单剖析
    day 28
    day 27
    day 26
    day 25
    day 24
    day 23
  • 原文地址:https://www.cnblogs.com/gczr/p/12396992.html
Copyright © 2020-2023  润新知