• Adding a teacher_forcing mechanism to a PyTorch seq2seq model


    Here the teacher forcing mechanism is applied inside the decoding loop, deciding at each time step whether to feed the ground-truth token. This works when the target is fixed and known for the whole sequence.

    When the target is not fixed, the decision has to be made outside the loop instead; a sketch of that variant follows.
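
    For the outside-the-loop variant, a common pattern is to decide once per batch whether to teacher-force the entire sequence. Inside Decoder.forward (shown below), the decoding loop would become something like this sketch; the teacher_forcing_ratio name and its 0.5 value are my own, not from the original code:

    teacher_forcing_ratio = 0.5  # hypothetical knob, not in the original config
    use_teacher_forcing = random.random() < teacher_forcing_ratio  # decided once, outside the loop

    for t in range(config.max_len):
        decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
        decoder_outputs[:, t, :] = decoder_output_t
        if use_teacher_forcing:
            # ground truth feeds every step of this sequence
            decoder_input = target[:, t].unsqueeze(-1)
        else:
            # the model's own greedy prediction feeds every step
            value, index = decoder_output_t.max(dim=-1)
            decoder_input = index.unsqueeze(-1)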

    Changes in decoder.py

    """
    实现解码器
    """
    import torch.nn as nn
    import config
    import torch
    import torch.nn.functional as F
    import numpy as np
    import random
    
    
    class Decoder(nn.Module):
        def __init__(self):
            super(Decoder, self).__init__()
    
            self.embedding = nn.Embedding(num_embeddings=len(config.ns),
                                          embedding_dim=50,
                                          padding_idx=config.ns.PAD)
    
            # expected hidden_state shape: [1, batch_size, 64]
            self.gru = nn.GRU(input_size=50,
                              hidden_size=64,
                              num_layers=1,
                              bidirectional=False,
                              batch_first=True,
                              dropout=0)
    
            # if the encoder had hidden_size=64, num_layers=1 and were bidirectional, encoder_hidden would be [2, batch_size, 64]
    
            self.fc = nn.Linear(64, len(config.ns))
    
        def forward(self, encoder_hidden, target):

            # hidden state fed into the first time step
            decoder_hidden = encoder_hidden  # [1,batch_size,encoder_hidden_size]
            # input token for the first time step: the SOS marker
            batch_size = encoder_hidden.size(1)
            decoder_input = torch.LongTensor([[config.ns.SOS]] * batch_size).to(config.device)  # [batch_size,1]
    
    
            # zero tensor that collects the output of every time step, [batch_size,max_len,vocab_size]
            decoder_outputs = torch.zeros([batch_size, config.max_len, len(config.ns)]).to(config.device)
    
            for t in range(config.max_len):
                decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
                decoder_outputs[:, t, :] = decoder_output_t
    
                # greedy prediction at the current time step
                value, index = decoder_output_t.max(dim=-1)
                if random.random() < 0.3:  # teacher forcing: feed the ground-truth token with probability 0.3
                    decoder_input = target[:, t].unsqueeze(-1)  # [batch_size,1]
                else:
                    decoder_input = index.unsqueeze(-1)  # [batch_size,1]
            return decoder_outputs, decoder_hidden
    
        def forward_step(self, decoder_input, decoder_hidden):
            '''
            Run a single decoding time step.
            :param decoder_input: [batch_size,1]
            :param decoder_hidden: [1,batch_size,hidden_size]
            :return: log-probabilities [batch_size,vocab_size] and the updated hidden state
            '''
    
            decoder_input_embeded = self.embedding(decoder_input)  # [batch_size,1,embedding_dim]
    
            out, decoder_hidden = self.gru(decoder_input_embeded, decoder_hidden)
    
            # out: [batch_size,1,hidden_size]

            out_squeezed = out.squeeze(dim=1)  # drop the length-1 time dimension
            out_fc = F.log_softmax(self.fc(out_squeezed), dim=-1)  # [batch_size,vocab_size]
            return out_fc, decoder_hidden
    
        def evaluate(self, encoder_hidden):
    
            # hidden state fed into the first time step
            decoder_hidden = encoder_hidden  # [1,batch_size,encoder_hidden_size]
            # input token for the first time step: the SOS marker
            batch_size = encoder_hidden.size(1)
            decoder_input = torch.LongTensor([[config.ns.SOS]] * batch_size).to(config.device)  # [batch_size,1]

            # zero tensor that collects the output of every time step, [batch_size,max_len,vocab_size]
            decoder_outputs = torch.zeros([batch_size, config.max_len, len(config.ns)]).to(config.device)

            # per-step greedy predictions; e.g. for input 123456 the target is 123456EOS,
            # and the prediction may keep generating past EOS, e.g. 123456EOS123
            decoder_predict = []
            for t in range(config.max_len):
                decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
                decoder_outputs[:, t, :] = decoder_output_t
    
                # greedy prediction at the current time step
                value, index = decoder_output_t.max(dim=-1)
                decoder_input = index.unsqueeze(-1)  # [batch_size,1]
                decoder_predict.append(index.cpu().detach().numpy())
    
            # stack per-step predictions into [batch_size,max_len]
            decoder_predict = np.array(decoder_predict).transpose()
            return decoder_outputs, decoder_predict
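
    To sanity-check the decoder's shapes without the rest of the project, one can stub the config module before importing it. Everything below (the fake vocabulary, the sizes) is made up purely for illustration:

    import sys
    import types

    import torch

    # stand-in for the project's config module, which is not shown in this post
    fake = types.ModuleType("config")

    class _Vocab:  # hypothetical vocabulary stub
        PAD, SOS = 0, 1
        def __len__(self):
            return 20

    fake.ns = _Vocab()
    fake.device = torch.device("cpu")
    fake.max_len = 10
    sys.modules["config"] = fake  # must run before `import decoder`

    from decoder import Decoder

    decoder = Decoder()
    encoder_hidden = torch.zeros(1, 4, 64)                   # [num_layers, batch_size, hidden_size]
    target = torch.zeros(4, fake.max_len, dtype=torch.long)  # dummy ground-truth tokens
    outputs, hidden = decoder(encoder_hidden, target)
    print(outputs.shape)  # torch.Size([4, 10, 20]) = [batch_size, max_len, vocab_size]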
    

      seq2seq.py

    """
    完成seq2seq模型
    """
    import torch.nn as nn
    from encoder import Encoder
    from decoder import Decoder
    
    
    class Seq2Seq(nn.Module):
        def __init__(self):
            super(Seq2Seq, self).__init__()
            self.encoder = Encoder()
            self.decoder = Decoder()
    
        def forward(self, input, input_len, target):
            encoder_outputs, encoder_hidden = self.encoder(input, input_len)
            decoder_outputs, decoder_hidden = self.decoder(encoder_hidden, target)
            return decoder_outputs
    
        def evaluate(self, input, input_len):
            encoder_outputs, encoder_hidden = self.encoder(input, input_len)
            decoder_outputs, decoder_predict = self.decoder.evaluate(encoder_hidden)
            return decoder_outputs, decoder_predict
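
    encoder.py is not shown in the post. For completeness, a minimal encoder matching the interface Seq2Seq expects might look like the sketch below; this is an assumption on my part, and the original may well use pack_padded_sequence or a bidirectional GRU (which would change the hidden-state shape handed to the decoder):

    import torch.nn as nn
    import config


    class Encoder(nn.Module):
        def __init__(self):
            super(Encoder, self).__init__()
            self.embedding = nn.Embedding(num_embeddings=len(config.ns),
                                          embedding_dim=50,
                                          padding_idx=config.ns.PAD)
            # hidden_size must match the decoder's GRU (64), because
            # encoder_hidden is handed to the decoder unchanged
            self.gru = nn.GRU(input_size=50,
                              hidden_size=64,
                              num_layers=1,
                              batch_first=True)

        def forward(self, input, input_len):
            # input_len would feed pack_padded_sequence in a fuller version
            embedded = self.embedding(input)  # [batch_size, seq_len, 50]
            out, hidden = self.gru(embedded)  # hidden: [1, batch_size, 64]
            return out, hidden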
    

      train.py

    """
    进行模型的训练
    """
    import torch
    import torch.nn.functional as F
    from seq2seq import Seq2Seq
    from torch.optim import Adam
    from dataset import get_dataloader
    from tqdm import tqdm
    import config
    import numpy as np
    import pickle
    from matplotlib import pyplot as plt
    from eval import eval
    import os
    
    model = Seq2Seq().to(config.device)
    optimizer = Adam(model.parameters())
    
    if os.path.exists("./models/model.pkl"):
        model.load_state_dict(torch.load("./models/model.pkl"))
        optimizer.load_state_dict(torch.load("./models/optimizer.pkl"))
    
    loss_list = []
    
    
    def train(epoch):
        data_loader = get_dataloader(train=True)
        bar = tqdm(data_loader, total=len(data_loader))
    
        for idx, (input, target, input_len, target_len) in enumerate(bar):
            input = input.to(config.device)
            target = target.to(config.device)
            input_len = input_len.to(config.device)
            optimizer.zero_grad()
            decoder_outputs = model(input, input_len, target)  # [batch_size,max_len,vocab_size]
            predict = decoder_outputs.view(-1, len(config.ns))
            target = target.view(-1)
            loss = F.nll_loss(predict, target, ignore_index=config.ns.PAD)
            loss.backward()
            optimizer.step()
            loss_list.append(loss.item())
            bar.set_description("epoch:{} idx:{} loss:{:.6f}".format(epoch, idx, np.mean(loss_list)))  # running mean over all recorded losses
    
            if idx % 100 == 0:
                torch.save(model.state_dict(), "./models/model.pkl")
                torch.save(optimizer.state_dict(), "./models/optimizer.pkl")
                pickle.dump(loss_list, open("./models/loss_list.pkl", "wb"))
    
    
    if __name__ == '__main__':
        for i in range(5):
            train(i)
            eval()
    
        plt.figure(figsize=(50, 8))
        plt.plot(range(len(loss_list)), loss_list)
        plt.show()
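
    As a small illustration of the loss computation above (all numbers made up): flattening decoder_outputs to [batch_size*max_len, vocab_size] lines it up with the flattened targets, and ignore_index makes padded positions contribute nothing to the loss:

    import torch
    import torch.nn.functional as F

    vocab_size, PAD = 5, 0
    # [batch_size=2, max_len=3, vocab_size] of log-probabilities, as the decoder produces
    decoder_outputs = torch.log_softmax(torch.randn(2, 3, vocab_size), dim=-1)
    target = torch.tensor([[1, 2, PAD],
                           [3, PAD, PAD]])  # padded ground-truth tokens

    loss = F.nll_loss(decoder_outputs.view(-1, vocab_size),
                      target.view(-1),
                      ignore_index=PAD)  # PAD positions are skipped
    print(loss.item())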
    

      
