import torch
import torch.nn as nn
import torch.nn.functional as F

from config import IGNORE_ID
from .attention import MultiHeadAttention
from .module import PositionalEncoding, PositionwiseFeedForward
from .utils import get_attn_key_pad_mask, get_attn_pad_mask, get_non_pad_mask, get_subsequent_mask, pad_list

# Disabled experiment: bigram-frequency rescoring of beam-search scores.
# filename = 'bigram_freq.pkl'
# print('loading {}...'.format(filename))
# with open(filename, 'rb') as file:
#     bigram_freq = pickle.load(file)


class Decoder(nn.Module):
    """A decoder model with self attention mechanism.

    Stacks ``n_layers`` :class:`DecoderLayer` modules on top of a target
    token embedding plus positional encoding, followed by a bias-free linear
    projection to vocabulary logits.
    """

    def __init__(
            self, sos_id=0, eos_id=1,
            n_tgt_vocab=4335, d_word_vec=512,
            n_layers=6, n_head=8, d_k=64, d_v=64,
            d_model=512, d_inner=2048, dropout=0.1,
            tgt_emb_prj_weight_sharing=True,
            pe_maxlen=5000):
        """Build the decoder.

        Args:
            sos_id: token id used as start-of-sentence.
            eos_id: token id used as end-of-sentence; also used as the
                padding value for decoder inputs.
            n_tgt_vocab: size of the target vocabulary.
            d_word_vec: embedding dimension.
            n_layers: number of decoder layers.
            n_head: number of attention heads per layer.
            d_k: per-head key dimension passed to the attention modules.
            d_v: per-head value dimension passed to the attention modules.
            d_model: model (hidden) dimension.
            d_inner: inner dimension of the position-wise feed-forward net.
            dropout: dropout rate.
            tgt_emb_prj_weight_sharing: if True, the output projection shares
                its weight matrix with the target embedding.
            pe_maxlen: maximum length passed to the positional encoding.
        """
        super().__init__()
        # Hyper-parameters kept on the instance.
        self.sos_id = sos_id  # Start of Sentence
        self.eos_id = eos_id  # End of Sentence
        self.n_tgt_vocab = n_tgt_vocab
        self.d_word_vec = d_word_vec
        self.n_layers = n_layers
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        self.d_model = d_model
        self.d_inner = d_inner
        self.tgt_emb_prj_weight_sharing = tgt_emb_prj_weight_sharing
        self.pe_maxlen = pe_maxlen

        self.tgt_word_emb = nn.Embedding(n_tgt_vocab, d_word_vec)
        self.positional_encoding = PositionalEncoding(d_model, max_len=pe_maxlen)
        # NOTE: ``self.dropout`` is the dropout *module*. The original code
        # first stored the float rate under the same name and immediately
        # overwrote it; that dead assignment has been removed.
        self.dropout = nn.Dropout(dropout)

        # Stack of decoder layers.
        self.layer_stack = nn.ModuleList([
            DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])

        # Final linear projection to vocabulary logits.
        self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False)
        nn.init.xavier_normal_(self.tgt_word_prj.weight)

        if tgt_emb_prj_weight_sharing:
            # Share the weight matrix between target word embedding & the
            # final logit dense layer.
            self.tgt_word_prj.weight = self.tgt_word_emb.weight
            # With shared weights the embeddings are multiplied by
            # d_model ** -0.5 before use (see forward / recognize_beam).
            self.x_logit_scale = (d_model ** -0.5)
        else:
            self.x_logit_scale = 1.

    def preprocess(self, padded_input):
        """Generate decoder input and output label from padded_input.

        Adds <sos> to the decoder input and <eos> to the decoder output label.

        Args:
            padded_input: batch of target sequences padded with IGNORE_ID.

        Returns:
            Tuple ``(ys_in_pad, ys_out_pad)`` of equally sized tensors: the
            inputs are padded with ``eos_id`` and the labels with IGNORE_ID.
        """
        # Strip the IGNORE_ID padding from every sequence.
        ys = [y[y != IGNORE_ID] for y in padded_input]
        # One-element tensors holding eos/sos, with the dtype/device of the
        # targets.
        eos = ys[0].new([self.eos_id])
        sos = ys[0].new([self.sos_id])
        ys_in = [torch.cat([sos, y], dim=0) for y in ys]  # prepend <sos>
        ys_out = [torch.cat([y, eos], dim=0) for y in ys]  # append <eos>
        # Re-pad: inputs with eos_id, labels with IGNORE_ID.
        # pys: utt x olen
        ys_in_pad = pad_list(ys_in, self.eos_id)
        ys_out_pad = pad_list(ys_out, IGNORE_ID)
        # Internal invariant: both get exactly one extra token per sequence.
        assert ys_in_pad.size() == ys_out_pad.size()
        return ys_in_pad, ys_out_pad

    def forward(self, padded_input, encoder_padded_outputs,
                encoder_input_lengths, return_attns=False):
        """Run the decoder over a batch of targets.

        Args:
            padded_input: N x To
            encoder_padded_outputs: N x Ti x H
            encoder_input_lengths: encoder lengths, forwarded to
                ``get_attn_pad_mask`` to build the encoder-decoder mask.
            return_attns: if True, also return the per-layer attention maps.

        Returns:
            ``(pred, gold)`` where ``pred`` are the logits before softmax and
            ``gold`` the padded output labels; when ``return_attns`` is True
            the lists of self-attention and encoder-decoder attention maps
            are appended to the tuple.
        """
        dec_slf_attn_list, dec_enc_attn_list = [], []

        # Get decoder input and output.
        ys_in_pad, ys_out_pad = self.preprocess(padded_input)

        # Prepare masks.
        non_pad_mask = get_non_pad_mask(ys_in_pad, pad_idx=self.eos_id)

        slf_attn_mask_subseq = get_subsequent_mask(ys_in_pad)
        slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=ys_in_pad,
                                                     seq_q=ys_in_pad,
                                                     pad_idx=self.eos_id)
        # A position is masked if either the key-padding mask or the
        # subsequent mask marks it.
        slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)

        output_length = ys_in_pad.size(1)
        dec_enc_attn_mask = get_attn_pad_mask(encoder_padded_outputs,
                                              encoder_input_lengths,
                                              output_length)

        # Forward: scaled token embedding plus positional encoding.
        dec_output = self.dropout(self.tgt_word_emb(ys_in_pad) * self.x_logit_scale +
                                  self.positional_encoding(ys_in_pad))

        for dec_layer in self.layer_stack:
            # FIX: this call was corrupted in the source (scrambled target
            # names and a positional argument after keyword arguments,
            # which is a SyntaxError). Restored to match the DecoderLayer
            # signature and the equivalent call in recognize_beam.
            dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
                dec_output, encoder_padded_outputs,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask,
                dec_enc_attn_mask=dec_enc_attn_mask)

            if return_attns:
                dec_slf_attn_list.append(dec_slf_attn)
                dec_enc_attn_list.append(dec_enc_attn)

        # Logits before softmax.
        seq_logit = self.tgt_word_prj(dec_output)

        # Return predictions and gold labels.
        pred, gold = seq_logit, ys_out_pad

        if return_attns:
            return pred, gold, dec_slf_attn_list, dec_enc_attn_list
        return pred, gold

    def recognize_beam(self, encoder_outputs, char_list, args):
        """Beam search, decode one utterance now.

        Args:
            encoder_outputs: T x H
            char_list: list of characters (only referenced by the disabled
                debug output below).
            args: object providing ``beam_size``, ``nbest`` and
                ``decode_max_len`` (0 means "use the encoder length T").

        Returns:
            nbest_hyps: up to ``args.nbest`` ended hypotheses sorted by
                descending score; each is a dict with ``'score'`` and
                ``'yseq'`` (a plain list of token ids including sos/eos).
        """
        # Search params.
        beam = args.beam_size
        nbest = args.nbest
        if args.decode_max_len == 0:
            maxlen = encoder_outputs.size(0)
        else:
            maxlen = args.decode_max_len

        # Add a batch dimension of size 1.
        encoder_outputs = encoder_outputs.unsqueeze(0)

        # Prepare sos: a 1 x 1 long tensor matching encoder_outputs
        # (via type_as).
        ys = torch.ones(1, 1).fill_(self.sos_id).type_as(encoder_outputs).long()
        # Loop-invariant 1 x 1 eos tensor, appended in the final iteration.
        eos = torch.ones(1, 1).fill_(self.eos_id).type_as(encoder_outputs).long()

        # yseq: 1xT
        hyp = {'score': 0.0, 'yseq': ys}
        hyps = [hyp]
        ended_hyps = []

        for i in range(maxlen):
            hyps_best_kept = []
            for hyp in hyps:
                ys = hyp['yseq']  # 1 x i

                # Disabled experiment: bigram-frequency rescoring.
                # last_id = ys.cpu().numpy()[0][-1]
                # freq = bigram_freq[last_id]
                # freq = torch.log(torch.from_numpy(freq))
                # freq = freq.type(torch.float).to(device)

                # -- Prepare masks
                non_pad_mask = torch.ones_like(ys).float().unsqueeze(-1)  # 1xix1
                slf_attn_mask = get_subsequent_mask(ys)

                # -- Forward
                dec_output = self.dropout(
                    self.tgt_word_emb(ys) * self.x_logit_scale +
                    self.positional_encoding(ys))

                for dec_layer in self.layer_stack:
                    dec_output, _, _ = dec_layer(
                        dec_output, encoder_outputs,
                        non_pad_mask=non_pad_mask,
                        slf_attn_mask=slf_attn_mask,
                        dec_enc_attn_mask=None)

                # Only the last position is needed to predict the next token.
                seq_logit = self.tgt_word_prj(dec_output[:, -1])
                local_scores = F.log_softmax(seq_logit, dim=1)
                # local_scores += freq  (disabled bigram rescoring)

                # topk scores
                local_best_scores, local_best_ids = torch.topk(
                    local_scores, beam, dim=1)

                for j in range(beam):
                    new_hyp = {
                        'score': hyp['score'] + local_best_scores[0, j],
                        # Extend the sequence by the j-th best token id.
                        'yseq': torch.cat(
                            [ys, local_best_ids[:, j:j + 1]], dim=1),
                    }
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(hyps_best_kept,
                                        key=lambda x: x['score'],
                                        reverse=True)[:beam]
            # end for hyp in hyps
            hyps = hyps_best_kept

            # Add eos in the final loop to avoid that there are no ended hyps.
            if i == maxlen - 1:
                for hyp in hyps:
                    hyp['yseq'] = torch.cat([hyp['yseq'], eos], dim=1)

            # Add ended hypotheses to a final list, and remove them from the
            # current hypotheses.
            # (this will be a problem, number of hyps < beam)
            remained_hyps = []
            for hyp in hyps:
                if hyp['yseq'][0, -1] == self.eos_id:
                    ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            hyps = remained_hyps
            if not hyps:
                # Every hypothesis has ended; further iterations would be
                # no-ops, so stop early.
                break
            # Debug output (disabled):
            # for hyp in hyps:
            #     print('hypo: ' + ''.join([char_list[int(x)]
            #                               for x in hyp['yseq'][0, 1:]]))
        # end for i in range(maxlen)

        nbest_hyps = sorted(ended_hyps, key=lambda x: x['score'],
                            reverse=True)[:nbest]
        # Compatible with LAS implementation: return token ids as plain lists.
        for hyp in nbest_hyps:
            hyp['yseq'] = hyp['yseq'][0].cpu().numpy().tolist()
        return nbest_hyps


class DecoderLayer(nn.Module):
    """Compose with three layers.

    Self-attention, encoder-decoder attention and a position-wise
    feed-forward network, applied in that order.
    """

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        """Create the three sub-layers.

        Args:
            d_model: model (hidden) dimension.
            d_inner: inner dimension of the feed-forward network.
            n_head: number of attention heads.
            d_k: per-head key dimension.
            d_v: per-head value dimension.
            dropout: dropout rate forwarded to every sub-layer.
        """
        super().__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, dec_input, enc_output, non_pad_mask=None,
                slf_attn_mask=None, dec_enc_attn_mask=None):
        """Apply the layer.

        Args:
            dec_input: decoder-side input to the layer.
            enc_output: encoder output attended to by the second sub-layer.
            non_pad_mask: optional multiplicative mask applied to the output
                of every sub-layer; skipped when None.
            slf_attn_mask: mask passed to the self-attention.
            dec_enc_attn_mask: mask passed to the encoder-decoder attention.

        Returns:
            Tuple ``(dec_output, dec_slf_attn, dec_enc_attn)``.
        """
        dec_output, dec_slf_attn = self.slf_attn(
            dec_input, dec_input, dec_input, mask=slf_attn_mask)
        # FIX: the mask defaults to None, but the original multiplied
        # unconditionally and raised TypeError in that case.
        if non_pad_mask is not None:
            dec_output *= non_pad_mask

        dec_output, dec_enc_attn = self.enc_attn(
            dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
        if non_pad_mask is not None:
            dec_output *= non_pad_mask

        dec_output = self.pos_ffn(dec_output)
        if non_pad_mask is not None:
            dec_output *= non_pad_mask

        return dec_output, dec_slf_attn, dec_enc_attn