• PyTorch seq2seq 闲聊机器人:使用 beam search 返回结果


    decoder.py

    """
    实现解码器
    """
    import heapq
    
    import torch.nn as nn
    import config
    import torch
    import torch.nn.functional as F
    import numpy as np
    import random
    from chatbot.attention import Attention
    
    
    class Decoder(nn.Module):
        """GRU decoder with attention for the seq2seq chatbot.

        Provides three decoding modes:

        * ``forward`` - training-time decoding (randomly with or without
          teacher forcing),
        * ``evaluate`` - greedy decoding for a whole batch,
        * ``evaluate_with_beam_search`` - beam-search decoding for a single
          sentence.

        Every per-step output is a log-probability distribution over the target
        vocabulary (``forward_step`` ends with ``log_softmax``).
        """

        def __init__(self):
            super(Decoder, self).__init__()

            self.embedding = nn.Embedding(num_embeddings=len(config.target_ws),
                                          embedding_dim=config.chatbot_decoder_embedding_dim,
                                          padding_idx=config.target_ws.PAD)

            # batch_first=True: inputs/outputs are [batch_size, seq_len, feature];
            # the hidden state is still [num_layers, batch_size, hidden_size].
            self.gru = nn.GRU(input_size=config.chatbot_decoder_embedding_dim,
                              hidden_size=config.chatbot_decoder_hidden_size,
                              num_layers=config.chatbot_decoder_number_layer,
                              bidirectional=False,
                              batch_first=True,
                              dropout=config.chatbot_decoder_dropout)

            # Projects the attention result onto the target vocabulary.
            self.fc = nn.Linear(config.chatbot_decoder_hidden_size, len(config.target_ws))
            self.attn = Attention(method="general")
            # Maps the concatenation [context_vector; gru_output] back to hidden_size.
            self.fc_attn = nn.Linear(config.chatbot_decoder_hidden_size * 2, config.chatbot_decoder_hidden_size, bias=False)

        def forward(self, encoder_hidden, target, encoder_outputs):
            """Decode ``config.chatbot_target_max_len`` steps for training.

            :param encoder_hidden: initial decoder hidden state; dimension 1 is
                read as the batch size
            :param target: ground-truth token ids, indexed as ``target[:, t]``
            :param encoder_outputs: encoder outputs consumed by the attention
            :return: ``(decoder_outputs, decoder_hidden)`` where decoder_outputs is
                ``[batch_size, max_len, vocab_size]`` log-probabilities
            """
            # The first hidden state is the encoder's final hidden state.
            decoder_hidden = encoder_hidden
            batch_size = encoder_hidden.size(1)
            # The first input token is SOS for every sequence: [batch_size, 1]
            decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device)

            # Zero-initialised buffer for every step's output: [batch_size, max_len, vocab_size]
            decoder_outputs = torch.zeros([batch_size, config.chatbot_target_max_len, len(config.target_ws)]).to(config.device)

            # With probability ~0.5 the model is fed its own predictions; otherwise
            # the ground truth is fed (teacher forcing).
            feed_own_prediction = random.random() > 0.5

            for t in range(config.chatbot_target_max_len):
                decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden, encoder_outputs)
                decoder_outputs[:, t, :] = decoder_output_t

                if feed_own_prediction:
                    # Next input is the arg-max prediction of the current step.
                    _, index = decoder_output_t.max(dim=-1)
                    decoder_input = index.unsqueeze(-1)  # [batch_size, 1]
                else:
                    # Teacher forcing: next input is the ground-truth token.
                    decoder_input = target[:, t].unsqueeze(-1)
            return decoder_outputs, decoder_hidden

        def forward_step(self, decoder_input, decoder_hidden, encoder_outputs):
            """Compute the result of a single decoding time step.

            :param decoder_input: token ids for this step, ``[batch_size, 1]``
            :param decoder_hidden: previous hidden state,
                ``[num_layers, batch_size, hidden_size]``
            :param encoder_outputs: encoder outputs used to build the context vector
            :return: ``(out_fc, decoder_hidden)`` where out_fc holds
                log-probabilities over the target vocabulary
            """
            decoder_input_embeded = self.embedding(decoder_input)

            # out: [batch_size, 1, hidden_size]
            out, decoder_hidden = self.gru(decoder_input_embeded, decoder_hidden)

            # ---- attention ----
            # 1. Attention weights. torch.bmm needs 3-D operands, so attn_weight must
            #    be 2-D before the unsqueeze (presumably [batch_size, encoder_len] -
            #    verify against the Attention module).
            attn_weight = self.attn(decoder_hidden, encoder_outputs)
            # 2. Context vector: weighted sum of the encoder outputs.
            context_vector = torch.bmm(attn_weight.unsqueeze(1), encoder_outputs).squeeze(1)
            # 3. Combine context and GRU output, project back to hidden_size.
            attention_result = torch.tanh(self.fc_attn(torch.cat([context_vector, out.squeeze(1)], dim=-1)))
            # ---- end of attention ----

            # Log-probabilities over the vocabulary: [batch_size, vocab_size]
            out_fc = F.log_softmax(self.fc(attention_result), dim=-1)
            return out_fc, decoder_hidden

        def evaluate(self, encoder_hidden, encoder_outputs):
            """Greedy decoding for a batch.

            :param encoder_hidden: initial decoder hidden state; dimension 1 is
                read as the batch size
            :param encoder_outputs: encoder outputs consumed by the attention
            :return: ``(decoder_outputs, predict_result)``; predict_result is an
                ndarray in which each row is one predicted token-id sequence
            """
            decoder_hidden = encoder_hidden
            batch_size = encoder_hidden.size(1)
            # The first input token is SOS for every sequence: [batch_size, 1]
            decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device)

            # Zero-initialised buffer for every step's output: [batch_size, max_len, vocab_size]
            decoder_outputs = torch.zeros([batch_size, config.chatbot_target_max_len, len(config.target_ws)]).to(
                config.device)

            predict_result = []
            for t in range(config.chatbot_target_max_len):
                decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden, encoder_outputs)
                decoder_outputs[:, t, :] = decoder_output_t

                # Arg-max prediction of the current step becomes the next input.
                _, index = decoder_output_t.max(dim=-1)
                predict_result.append(index.cpu().detach().numpy())  # one [batch] array per step
                decoder_input = index.unsqueeze(-1)  # [batch_size, 1]
            # Steps were collected along axis 0; transpose so each row is one sequence.
            predict_result = np.array(predict_result).transpose()
            return decoder_outputs, predict_result

        def evaluate_with_beam_search(self, encoder_hidden, encoder_outputs):
            """Decode a single sentence with beam search.

            Hypotheses are scored by the SUM of their per-token log-probabilities
            (``forward_step`` returns ``log_softmax`` outputs, so multiplying the
            values would flip the sign of the score at every step).

            :param encoder_hidden: initial decoder hidden state; batch size must be 1
            :param encoder_outputs: encoder outputs consumed by the attention
            :return: list of predicted token ids with a leading SOS and trailing
                EOS stripped
            """
            decoder_hidden = encoder_hidden
            batch_size = encoder_hidden.size(1)
            assert batch_size == 1, "beam search的过程中,batch_size只能为1"
            decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device)  # [1, 1]

            prev_beam = Beam()
            # The SOS-only hypothesis has probability 1, i.e. log-probability 0.
            prev_beam.add(0.0, False, [decoder_input], decoder_input, decoder_hidden)

            while True:
                cur_beam = Beam()
                for score, complete, seq_list, step_input, step_hidden in prev_beam:
                    if complete:
                        # Already reached EOS earlier but was not the best hypothesis
                        # then; carry it over unchanged so it keeps competing.
                        cur_beam.add(score, complete, seq_list, step_input, step_hidden)
                        continue

                    decoder_output_t, next_hidden = self.forward_step(step_input, step_hidden, encoder_outputs)

                    # Expand the hypothesis with its beam_width most likely next tokens.
                    value, index = torch.topk(decoder_output_t, config.beam_width)
                    for m, n in zip(value[0], index[0]):
                        token_id = n.item()
                        cur_score = score + m.item()  # accumulate log-probabilities
                        next_input = torch.LongTensor([[token_id]]).to(config.device)
                        cur_complete = token_id == config.target_ws.EOS
                        cur_beam.add(cur_score, cur_complete, seq_list + [next_input], next_input, next_hidden)

                # Rank by score only, so ties never fall through to comparing tensors.
                _, best_complete, best_seq, _, _ = max(cur_beam, key=lambda item: item[0])
                # seq_list includes the initial SOS, hence the "- 1".
                if best_complete or len(best_seq) - 1 >= config.chatbot_target_max_len:
                    best_seq = [i.item() for i in best_seq]
                    if best_seq[0] == config.target_ws.SOS:
                        best_seq = best_seq[1:]
                    if best_seq[-1] == config.target_ws.EOS:
                        best_seq = best_seq[:-1]
                    return best_seq

                prev_beam = cur_beam
    
    
    class Beam:
        """Bounded container holding the best hypotheses of one beam-search step.

        Backed by a min-heap ordered on ``prob`` (the first field of each item), so
        the lowest-scoring hypothesis is the one evicted when the beam overflows.
        """

        def __init__(self):
            self.heapq = list()
            self.beam_width = config.beam_width

        def add(self, prob, complete, seq_list, decoder_input, decoder_hidden):
            """Insert a hypothesis, keeping at most ``beam_width`` entries.

            :param prob: score of the hypothesis (higher is better)
            :param complete: True once the hypothesis has produced EOS
            :param seq_list: tokens generated so far
            :param decoder_input: input for the next decoding step
            :param decoder_hidden: hidden state for the next decoding step
            """
            heapq.heappush(self.heapq, [prob, complete, seq_list, decoder_input, decoder_hidden])
            # Keep only beam_width results: heappop drops the smallest score.
            if len(self.heapq) > self.beam_width:
                heapq.heappop(self.heapq)

        def __iter__(self):
            """Yield the stored ``[prob, complete, seq_list, decoder_input, decoder_hidden]`` items in heap order."""
            for item in self.heapq:
                yield item
    

      seq2seq.py

    """
    完成seq2seq模型
    """
    import torch.nn as nn
    from chatbot.encoder import Encoder
    from chatbot.decoder import Decoder
    
    
    class Seq2Seq(nn.Module):
        """Encoder-decoder wrapper: runs the encoder, then hands its results to the decoder."""

        def __init__(self):
            super().__init__()
            self.encoder = Encoder()
            self.decoder = Decoder()

        def forward(self, input, input_len, target):
            """Training pass; returns only the decoder's per-step outputs."""
            enc_outputs, enc_hidden = self.encoder(input, input_len)
            # The decoder also returns its final hidden state, which is not needed here.
            step_outputs, _ = self.decoder(enc_hidden, target, enc_outputs)
            return step_outputs

        def evaluate(self, input, input_len):
            """Decode via ``Decoder.evaluate``; returns ``(decoder_outputs, predict_result)``."""
            enc_outputs, enc_hidden = self.encoder(input, input_len)
            outputs, predictions = self.decoder.evaluate(enc_hidden, enc_outputs)
            return outputs, predictions

        def evaluate_with_beam_search(self, input, input_len):
            """Decode via ``Decoder.evaluate_with_beam_search``; returns its best sequence."""
            enc_outputs, enc_hidden = self.encoder(input, input_len)
            return self.decoder.evaluate_with_beam_search(enc_hidden, enc_outputs)
    

      eval.py

    """
    进行模型的评估
    """
    
    import torch
    import torch.nn.functional as F
    from chatbot.dataset import get_dataloader
    from tqdm import tqdm
    import config
    import numpy as np
    import pickle
    from chatbot.seq2seq import Seq2Seq
    
    def eval():
        """Evaluate the saved model on the non-training data loader and print the mean NLL loss.

        Loads ``./models/model.pkl``, decodes every batch with ``model.evaluate``
        and averages ``F.nll_loss`` over all batches.
        """
        model = Seq2Seq().to(config.device)
        model.eval()
        # map_location lets a checkpoint saved on GPU load on a CPU-only host
        # (assumes config.device is a torch.device or device string).
        model.load_state_dict(torch.load("./models/model.pkl", map_location=config.device))

        loss_list = []
        data_loader = get_dataloader(train=False)
        bar = tqdm(data_loader, total=len(data_loader), desc="当前进行评估")
        with torch.no_grad():
            for idx, (batch_input, target, input_len, target_len) in enumerate(bar):
                batch_input = batch_input.to(config.device)
                target = target.to(config.device)
                input_len = input_len.to(config.device)

                # decoder_outputs: [batch_size, max_len, vocab_size]
                decoder_outputs, predict_result = model.evaluate(batch_input, input_len)
                # Targets are indexed in the target vocabulary, so padding must be
                # ignored using the target vocabulary's PAD id.
                loss = F.nll_loss(decoder_outputs.view(-1, len(config.target_ws)),
                                  target.view(-1),
                                  ignore_index=config.target_ws.PAD)
                loss_list.append(loss.item())
                bar.set_description("idx:{} loss:{:.6f}".format(idx, np.mean(loss_list)))
        print("当前的平均损失为:", np.mean(loss_list))
    
    
    def interface():
        """Interactive chat loop using greedy decoding.

        Loads ``./models/model.pkl`` once, then repeatedly reads a line from the
        user, tokenises it with ``cut(..., by_word=True)``, decodes a reply with
        ``model.evaluate`` and prints it. The loop never exits on its own.
        """
        from chatbot.cut_sentence import cut
        import config
        # Load the model.
        model = Seq2Seq().to(config.device)
        model.eval()
        # map_location lets a checkpoint saved on GPU load on a CPU-only host
        # (assumes config.device is a torch.device or device string).
        model.load_state_dict(torch.load("./models/model.pkl", map_location=config.device))

        while True:
            origin_input = input("me>>:")
            _input = cut(origin_input, by_word=True)
            input_len = torch.LongTensor([len(_input)]).to(config.device)
            _input = torch.LongTensor([config.input_ws.transform(_input, max_len=config.chatbot_input_max_len)]).to(config.device)

            # Pure inference: no autograd graph is needed (same as eval()).
            with torch.no_grad():
                outputs, predict = model.evaluate(_input, input_len)
            result = config.target_ws.inverse_transform(predict[0])
            print("chatbot>>:", result)
    
    
    def interface_with_beamsearch():
        """Interactive chat loop using beam-search decoding.

        Loads ``./models/model.pkl`` once, then repeatedly reads a line from the
        user, tokenises it with ``cut(..., by_word=True)``, decodes a reply with
        ``model.evaluate_with_beam_search`` and prints it. The loop never exits
        on its own.
        """
        from chatbot.cut_sentence import cut
        import config
        # Load the model.
        model = Seq2Seq().to(config.device)
        model.eval()
        # map_location lets a checkpoint saved on GPU load on a CPU-only host
        # (assumes config.device is a torch.device or device string).
        model.load_state_dict(torch.load("./models/model.pkl", map_location=config.device))

        while True:
            origin_input = input("me>>:")
            _input = cut(origin_input, by_word=True)
            input_len = torch.LongTensor([len(_input)]).to(config.device)
            _input = torch.LongTensor([config.input_ws.transform(_input, max_len=config.chatbot_input_max_len)]).to(
                config.device)

            # Pure inference: no autograd graph is needed (same as eval()).
            with torch.no_grad():
                best_seq = model.evaluate_with_beam_search(_input, input_len)
            result = config.target_ws.inverse_transform(best_seq)
            print("chatbot>>:", result)
    
    
    
    
    # Script entry point: start the interactive beam-search chat loop.
    if __name__ == '__main__':
        # Greedy-decoding alternative, currently disabled: interface()
        interface_with_beamsearch()
    

      

    多思考也是一种努力,做出正确的分析和选择,因为我们的时间和精力都有限,所以把时间花在更有价值的地方。
  • 相关阅读:
    eclipse使用svn
    yum安装mysql
    spring中aop使用
    mybatis定义拦截器
    横扫页面的三大标签
    springmvc日期格式化
    springmvc笔记
    springboot跳转jsp页面
    常用网址
    CentOS Android Studio桌面图标的创建
  • 原文地址:https://www.cnblogs.com/LiuXinyu12378/p/12386164.html
Copyright © 2020-2023  润新知