• Transformer[vanilla]


    0 引言


    transformer这个框架现在可谓是遍地开花,继最开始的AE,CNN,RNN,到现在的transformer,该框架从nlp席卷CV,乃至ASR领域。
    本文以《The Illustrated Transformer》和《【译】The Annotated Transformer》为来源,主要从总到分的角度去阅读代码。

    其实就是觉得The Annotated Transformer写得非常好,但是诸多教程都喜欢先展示一堆零散的材料,最后才告诉你组装结果;这不太符合我自己的理解习惯,因为等到最后组装的时候,往往已经记不清前面那一堆碎片各自是做什么用的了。

    github的地址:https://github.com/harvardnlp/annotated-transformer

    1 基于最外层进行示意

    1.1 包引用

    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import math, copy, time
    from torch.autograd import Variable
    import matplotlib.pyplot as plt
    import seaborn
    seaborn.set_context(context="talk")
    

    这个没得说。

    1.2 构建模型 make_model

    大多数神经网络翻译模型都是encoder-decoder结构,如最经典的seq2seq:编码器将输入序列映射为中间特征表示,解码器再基于该中间表示生成目标序列。所以这里的核心就是EncoderDecoder部分,以它为拼图的中心。

    def make_model(src_vocab, tgt_vocab, N=6,
                   d_model=512, d_ff=2048, h=8, dropout=0.1):
        """Helper: Construct a model from hyperparameters.

        Args:
            src_vocab: size of the source vocabulary.
            tgt_vocab: size of the target vocabulary.
            N: number of stacked encoder layers and of decoder layers.
            d_model: model (embedding) dimension.
            d_ff: inner dimension of the position-wise feed-forward network.
            h: number of attention heads; must divide ``d_model``.
            dropout: dropout probability passed to the layers, the
                feed-forward network and the positional encoding. The
                attention module is built with its own default dropout.

        Returns:
            An ``EncoderDecoder`` whose weight matrices are Xavier-initialised.
        """
        # Every sub-module receives its own deep copy, so no parameters are
        # shared between layers.
        c = copy.deepcopy

        # Prototype components: multi-head attention, position-wise
        # feed-forward network and positional encoding.
        attn = MultiHeadedAttention(h, d_model)
        ff = PositionwiseFeedForward(d_model, d_ff, dropout)
        position = PositionalEncoding(d_model, dropout)

        # Model structure: encoder, decoder (self- and source-attention are
        # separate copies), source/target embedding + positional encoding,
        # and the output generator.
        model = EncoderDecoder(
            Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
            Decoder(DecoderLayer(d_model, c(attn), c(attn),
                                 c(ff), dropout), N),
            nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
            nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
            Generator(d_model, tgt_vocab))

        # This was important from their code.
        # Initialize parameters with Glorot / fan_avg. Only matrices
        # (dim > 1) are touched; biases and LayerNorm vectors keep their
        # defaults. ``xavier_uniform_`` is the in-place, non-deprecated form
        # of ``xavier_uniform``.
        for p in model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        return model
    

    上述代码主要可分成两部分:

    • 三个额外的组件:多头注意力、逐位置前馈网络、位置编码;
    • 模型结构:EncoderDecoder,其中包含编码器、解码器、源序列嵌入层(Embeddings+位置编码)、目标序列嵌入层(Embeddings+位置编码)以及生成器。此外,最后还用Xavier均匀分布初始化了所有维度大于1的参数。

    1.3 构造数据 data_gen

    主要是构造一些假数据,用于一个简单的复制任务:源序列和目标序列完全相同。

    def data_gen(V, batch, nbatches):
        """Generate random data for a src-tgt copy task.

        Yields ``nbatches`` ``Batch`` objects built with padding id 0. Each
        of the ``batch`` sequences has length 10, token ids are drawn
        uniformly from ``[1, V)``, and the first token is forced to 1.
        Source and target are the same tensor because the task is to copy
        the input.
        """
        for _ in range(nbatches):
            # np.random.randint's default integer dtype is platform-dependent
            # (int32 on Windows), while index ops used downstream such as
            # ``scatter_`` require int64, so cast explicitly.
            data = torch.from_numpy(
                np.random.randint(1, V, size=(batch, 10))).long()
            data[:, 0] = 1
            src = Variable(data, requires_grad=False)
            tgt = Variable(data, requires_grad=False)
            yield Batch(src, tgt, 0)
    

    1.4 训练

    1.4.1 简单的loss计算 SimpleLossCompute

    class SimpleLossCompute:
        "A simple loss compute and train function."
        def __init__(self, generator, criterion, opt=None):
            # generator: maps decoder output to (log-)probabilities.
            # criterion: loss taking (N, vocab) predictions and (N,) targets.
            # opt: wrapper exposing step() and ``.optimizer``; when None the
            #      call only evaluates the loss and never touches gradients.
            self.generator = generator
            self.criterion = criterion
            self.opt = opt

        def __call__(self, x, y, norm):
            """Compute the loss for one batch and, when training, update.

            Args:
                x: decoder output, passed through ``self.generator``.
                y: target token ids; flattened to 1-D for the criterion.
                norm: normaliser (e.g. number of target tokens).

            Returns:
                The un-normalised loss ``loss * norm`` (a Python float, or a
                tensor when ``norm`` is a tensor).
            """
            x = self.generator(x)
            loss = self.criterion(x.contiguous().view(-1, x.size(-1)),
                                  y.contiguous().view(-1)) / norm
            if self.opt is not None:
                # Only back-propagate when training. Calling backward() during
                # evaluation would accumulate gradients that the next
                # optimizer step would then apply.
                loss.backward()
                self.opt.step()  # apply the gradients
                self.opt.optimizer.zero_grad()  # reset for the next batch
            # ``loss`` is a 0-dim tensor; ``.data[0]`` raises on it in
            # PyTorch >= 0.4, so extract the value with ``.item()``.
            return loss.item() * norm
    

    1.4.2 训练过程

    # Train the simple copy task.
    V = 11
    # Label-smoothing criterion. With smoothing=0.0 the target distribution
    # is plain one-hot (with the padding column zeroed).
    criterion = LabelSmoothing(size=V, padding_idx=0, smoothing=0.0)
    
    model = make_model(V, V, N=2)
    
    # NoamOpt wraps Adam and overwrites its learning rate on every step
    # (factor=1, warmup=400). Adam's lr=0 is therefore only a placeholder.
    model_opt = NoamOpt(model.src_embed[0].d_model, 1, 400,
            torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
    
    for epoch in range(10):
        model.train()
        run_epoch(data_gen(V, 30, 20), model, 
                  # Training pass: the loss helper receives the optimizer.
                  SimpleLossCompute(model.generator, criterion, model_opt))
    
        model.eval()
        # Evaluation pass (no optimizer): print the average loss per token.
        print(run_epoch(data_gen(V, 30, 5), model, 
                        SimpleLossCompute(model.generator, criterion, None)))
    

    1.5 推理 greedy_decode

    def greedy_decode(model, src, src_mask, max_len, start_symbol,
                      end_symbol=None):
        """Greedily decode one source sequence.

        At every step the decoder is run on all tokens produced so far and
        the highest-scoring next token is appended.

        Args:
            model: an ``EncoderDecoder`` (uses ``encode``, ``decode`` and
                ``generator``).
            src: source token ids. The output is built from (1, 1) tensors,
                so this only supports a batch of one sequence.
            src_mask: source padding mask passed to encoder and decoder.
            max_len: maximum length of the output, including the start symbol.
            start_symbol: token id the output starts with.
            end_symbol: optional token id; decoding stops right after it is
                produced. ``None`` (default) always decodes ``max_len`` tokens.

        Returns:
            A (1, length) tensor of token ids with the dtype of ``src``.
        """
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            memory = model.encode(src, src_mask)
            ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
            for _ in range(max_len - 1):
                # Causal mask so each position only sees earlier outputs.
                tgt_mask = subsequent_mask(ys.size(1)).type_as(src.data)
                out = model.decode(memory, src_mask, ys, tgt_mask)
                # Score only the last position to choose the next token.
                prob = model.generator(out[:, -1])
                _, next_word = torch.max(prob, dim=1)
                next_word = next_word.item()
                ys = torch.cat(
                    [ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)],
                    dim=1)
                if end_symbol is not None and next_word == end_symbol:
                    break
        return ys
    
    # Decode the sequence 1..10 with the copy-task model trained above. The
    # copy task expects the output to echo the input.
    model.eval()
    src = Variable(torch.LongTensor([[1,2,3,4,5,6,7,8,9,10]]) )
    # All-ones mask: every source position is visible.
    src_mask = Variable(torch.ones(1, 1, 10) )
    print(greedy_decode(model, src, src_mask, max_len=10, start_symbol=1))
    

    2 EncoderDecoder及内部部分

    2.1 EncoderDecoder结构

    class EncoderDecoder(nn.Module):
        """Generic encoder-decoder wrapper.

        Holds five pluggable parts (encoder, decoder, source embedding,
        target embedding and generator) and wires them together. ``forward``
        stops at the decoder output; the generator is stored so callers can
        apply it themselves.
        """

        def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
            super().__init__()
            self.encoder, self.decoder = encoder, decoder
            self.src_embed, self.tgt_embed = src_embed, tgt_embed
            self.generator = generator

        def forward(self, src, tgt, src_mask, tgt_mask):
            "Take in and process masked src and target sequences."
            memory = self.encode(src, src_mask)
            return self.decode(memory, src_mask, tgt, tgt_mask)

        def encode(self, src, src_mask):
            """Embed ``src`` and run it through the encoder."""
            embedded = self.src_embed(src)
            return self.encoder(embedded, src_mask)

        def decode(self, memory, src_mask, tgt, tgt_mask):
            """Embed ``tgt`` and decode it against the encoder ``memory``."""
            embedded = self.tgt_embed(tgt)
            return self.decoder(embedded, memory, src_mask, tgt_mask)
    

    2.2 前置函数及层

    2.2.1 复制模块函数 clones

    clones用于便捷地生成N个相同结构的模块:通过deepcopy复制N份(各份参数互相独立),并放入nn.ModuleList中,以便它们被正确注册为子模块。

    def clones(module, N):
        """Return an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``.

        The copies share no parameters with each other or with ``module``.
        """
        copies = []
        for _ in range(N):
            copies.append(copy.deepcopy(module))
        return nn.ModuleList(copies)
    

    2.2.2 层归一化 LayerNorm

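    LN的公式(对最后一维计算均值\(\bar{x}\)和标准差\(\mathrm{std}(x)\),\(a\)、\(b\)为可学习的缩放和偏置参数,即代码中的a_2、b_2):$$a\odot\frac{x-\bar{x}}{\mathrm{std}(x)+\epsilon}+b$$ 注意这里的eps加在标准差上,且x.std默认是无偏估计,这与torch.nn.LayerNorm(在方差上加eps、使用有偏方差)略有不同。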

    class LayerNorm(nn.Module):
        """Layer normalisation over the last dimension.

        Computes ``a_2 * (x - mean) / (std + eps) + b_2``, where ``std`` is
        ``Tensor.std`` (unbiased by default) and ``eps`` is added to the
        standard deviation rather than to the variance.
        """

        def __init__(self, features, eps=1e-6):
            super().__init__()
            # Learnable gain (initialised to 1) and bias (initialised to 0),
            # one value per feature.
            self.a_2 = nn.Parameter(torch.ones(features))
            self.b_2 = nn.Parameter(torch.zeros(features))
            self.eps = eps

        def forward(self, x):
            centered = x - x.mean(-1, keepdim=True)
            denom = x.std(-1, keepdim=True) + self.eps
            return self.a_2 * centered / denom + self.b_2
    

    2.2.3 子层连接 SublayerConnection

    class SublayerConnection(nn.Module):
        """Pre-norm residual wrapper: ``x + dropout(sublayer(norm(x)))``.

        The LayerNorm is applied to the sub-layer's input (pre-norm) rather
        than after the residual sum, which keeps the code simple.
        """

        def __init__(self, size, dropout):
            super().__init__()
            self.norm = LayerNorm(size)
            self.dropout = nn.Dropout(dropout)

        def forward(self, x, sublayer):
            "Apply residual connection to any sublayer with the same size."
            normed = self.norm(x)
            residual = self.dropout(sublayer(normed))
            return x + residual
    

    2.2.4 嵌入 Embeddings

    将输入单词和输出单词都转换成维度为\(d_{model}\)的embedding,并将embedding的输出乘以\(\sqrt{d_{model}}\)进行放大(放大的是查表得到的结果,而不是embedding的权重本身)。

    class Embeddings(nn.Module):
        """Token embedding lookup whose output is scaled by ``sqrt(d_model)``."""

        def __init__(self, d_model, vocab):
            super().__init__()
            self.lut = nn.Embedding(vocab, d_model)
            self.d_model = d_model

        def forward(self, x):
            scale = math.sqrt(self.d_model)
            return self.lut(x) * scale
    

    2.2.5 subsequent_mask组件及其例子

    构建mask,只保留下三角部分(含对角线):位置i只能看到位置0到i,从而在解码时屏蔽后续位置。

    def subsequent_mask(size):
        """Return a boolean causal mask of shape ``(1, size, size)``.

        Entry ``[0, i, j]`` is True when ``j <= i``, i.e. position ``i`` may
        attend only to itself and earlier positions.
        """
        ones = torch.ones(1, size, size)
        future = torch.triu(ones, diagonal=1)
        return future == 0
    

    # Visualise a 20x20 causal mask: row i (query position) may attend to
    # columns j <= i.
    plt.figure(figsize=(5,5))
    plt.imshow(subsequent_mask(20)[0])
    

    2.3 编码

    2.3.1 编码器

    将传入的layer复制N份,输入x依次前向通过这N层,最后对结果进行LayerNorm规范化并输出。

    class Encoder(nn.Module):
        """Stack of ``N`` copies of ``layer`` followed by a final LayerNorm."""

        def __init__(self, layer, N):
            super().__init__()
            self.layers = clones(layer, N)
            # Each layer ends with an un-normalised residual sum (pre-norm
            # sub-layers), so the stack output is normalised once at the end.
            self.norm = LayerNorm(layer.size)

        def forward(self, x, mask):
            "Pass the input (and mask) through each layer in turn."
            hidden = x
            for encoder_layer in self.layers:
                hidden = encoder_layer(hidden, mask)
            return self.norm(hidden)
    

    2.3.2 编码层

    如图所示,

    class EncoderLayer(nn.Module):
        """One encoder layer: self-attention, then feed-forward.

        Each of the two steps is wrapped in its own pre-norm residual
        ``SublayerConnection``.
        """

        def __init__(self, size, self_attn, feed_forward, dropout):
            super().__init__()
            self.self_attn = self_attn
            self.feed_forward = feed_forward
            self.sublayer = clones(SublayerConnection(size, dropout), 2)
            # Read by Encoder to size its final LayerNorm.
            self.size = size

        def forward(self, x, mask):
            "Follow Figure 1 (left) for connections."
            def attend(h):
                # Self-attention: queries, keys and values all come from h.
                return self.self_attn(h, h, h, mask)

            attn_block, ff_block = self.sublayer
            x = attn_block(x, attend)
            return ff_block(x, self.feed_forward)
    

    2.4 解码

    2.4.1 解码器

    class Decoder(nn.Module):
        """Stack of ``N`` copies of a decoder layer followed by a final LayerNorm."""

        def __init__(self, layer, N):
            super().__init__()
            self.layers = clones(layer, N)
            self.norm = LayerNorm(layer.size)

        def forward(self, x, memory, src_mask, tgt_mask):
            """Run ``x`` through every layer; each layer also sees ``memory``."""
            hidden = x
            for decoder_layer in self.layers:
                hidden = decoder_layer(hidden, memory, src_mask, tgt_mask)
            return self.norm(hidden)
    

    2.4.2 解码层

    class DecoderLayer(nn.Module):
        """One decoder layer: masked self-attention, source attention, feed-forward.

        The source (encoder-decoder) attention attends over the encoder
        output ``memory``. Each of the three steps is wrapped in its own
        pre-norm residual ``SublayerConnection``.
        """

        def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
            super().__init__()
            self.size = size
            self.self_attn = self_attn
            # In make_model this is a separate deep copy of the same attention
            # prototype, so it has its own parameters.
            self.src_attn = src_attn
            self.feed_forward = feed_forward
            self.sublayer = clones(SublayerConnection(size, dropout), 3)

        def forward(self, x, memory, src_mask, tgt_mask):
            "Follow Figure 1 (right) for connections."
            self_block, cross_block, ff_block = self.sublayer
            x = self_block(x, lambda q: self.self_attn(q, q, q, tgt_mask))
            # Queries come from the decoder; keys and values from the encoder.
            x = cross_block(x, lambda q: self.src_attn(q, memory, memory, src_mask))
            return ff_block(x, self.feed_forward)
    

    2.5 生成器

    class Generator(nn.Module):
        """Project decoder states to vocabulary log-probabilities.

        Applies a linear layer followed by ``log_softmax`` over the last dim.
        """

        def __init__(self, d_model, vocab):
            super().__init__()
            self.proj = nn.Linear(d_model, vocab)

        def forward(self, x):
            logits = self.proj(x)
            return F.log_softmax(logits, dim=-1)
    

    3 MultiHeadedAttention

    3.1 前置attention


    一个attention函数可以描述为:将一个query和一组key-value对映射到输出,其中query、keys、values和输出都是向量。输出为values的加权和,其中每个value的权重由query与对应的key计算得到。queries和keys的维度都为\(d_k\),values的维度为\(d_v\):

    • 1)先计算query和所有keys的点积;
    • 2)上述点积结果除以\(\sqrt{d_k}\);
    • 3)将上述结果输入到softmax,获得每个value的权重。

    为了一次性同时计算所有的queries,将queries放入矩阵\(Q\),对应的keys和values为\(K\),\(V\):

    \[Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})V \]

    一般而言,最常用的attention函数有两种:1)加法attention;2)点积attention。本文用的就是点积法,只不过多了缩放因子\(\frac{1}{\sqrt{d_k}}\)。加法attention使用一个只有单个隐藏层的前馈网络来计算兼容性函数。两者的理论复杂度相近,但点积attention在实践中更快、也更省空间,因为它可以利用高度优化的矩阵乘法代码。

    当\(d_k\)较小时,两种机制的效果相近;但当\(d_k\)较大时,不加缩放的点积attention效果反而不如加法attention。我们推测这是因为\(d_k\)很大时,点积值的量级会变得很大,把softmax函数推到梯度极小的区域。为了抵消这个影响,才将点积结果乘以\(\frac{1}{\sqrt{d_k}}\)。

    def attention(query, key, value, mask=None, dropout=None):
        """Compute 'Scaled Dot Product Attention'.

        ``softmax(query @ key^T / sqrt(d_k)) @ value``, where ``d_k`` is the
        size of the last dimension of ``query``.

        Args:
            query, key, value: tensors whose last two dims are
                (positions, features); leading dims are broadcast by
                ``torch.matmul``.
            mask: optional tensor broadcastable to the score matrix; entries
                where ``mask == 0`` are excluded from the softmax.
            dropout: optional module applied to the attention weights.

        Returns:
            ``(output, p_attn)``: the weighted sum of values and the
            attention weights.
        """
        d_k = query.size(-1)

        # 1-MatMul & 2-Scale
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

        # 3-Mask (optional)
        if mask is not None:
            # -1e9 is not representable in float16 (masked_fill raises), so
            # clamp to the most negative finite value of the score dtype.
            # For float32/float64 this is still exactly -1e9.
            fill_value = max(-1e9, torch.finfo(scores.dtype).min)
            scores = scores.masked_fill(mask == 0, fill_value)

        # 4-Softmax
        p_attn = F.softmax(scores, dim=-1)
        if dropout is not None:
            p_attn = dropout(p_attn)

        # 5-MatMul
        return torch.matmul(p_attn, value), p_attn
    

    3.2 MultiHeadedAttention组件

    \[MultiHead(Q,K,V) = Concat(head_1,head_2,\dots,head_h)W^O \]

    其中

    \[head_i = Attention(QW_i^Q,KW_i^K,VW_i^V) \]

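    这里,投影矩阵为\(W_i^Q\in R^{d_{model}\times d_k}\),\(W_i^K\in R^{d_{model}\times d_k}\),\(W_i^V\in R^{d_{model}\times d_v}\),\(W^O\in R^{hd_v\times d_{model}}\)。
    本文中\(h=8\),并且\(d_k=d_v=\frac{d_{model}}{h}=64\),则
    \(W_i^Q\in R^{512\times 64}\),\(W_i^K\in R^{512\times 64}\),\(W_i^V\in R^{512\times 64}\),\(W^O\in R^{512\times 512}\)(代码实现中把h个头的\(W_i^Q\)拼成一个\(512\times 512\)的线性层,K、V同理)。
    因为减少了每个head的维度,所以总的计算量与全维度的单个head差不多。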

    class MultiHeadedAttention(nn.Module):
        """Multi-head attention (Figure 2 of the paper).

        ``d_model`` is split into ``h`` heads of size ``d_k = d_model // h``.
        Four ``d_model x d_model`` linear layers are used: three input
        projections (query, key, value) and one output projection.
        """

        def __init__(self, h, d_model, dropout=0.1):
            "Take in model size and number of heads."
            super().__init__()
            assert d_model % h == 0
            # d_v is assumed to equal d_k (e.g. d_model=512, h=8 -> d_k=64).
            self.d_k = d_model // h
            self.h = h
            self.linears = clones(nn.Linear(d_model, d_model), 4)
            # Attention weights from the most recent forward call.
            self.attn = None
            self.dropout = nn.Dropout(p=dropout)

        def forward(self, query, key, value, mask=None):
            "Implements Figure 2"
            if mask is not None:
                # Insert a head axis so the same mask applies to every head.
                mask = mask.unsqueeze(1)
            batch_size = query.size(0)

            def split_heads(proj, t):
                # Project, then reshape to (batch, h, positions, d_k).
                # Assumes t is (batch, positions, d_model) -- TODO confirm.
                return proj(t).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)

            q_proj, k_proj, v_proj, out_proj = self.linears
            q = split_heads(q_proj, query)
            k = split_heads(k_proj, key)
            v = split_heads(v_proj, value)

            # Attend in all heads at once.
            x, self.attn = attention(q, k, v, mask=mask, dropout=self.dropout)

            # Merge heads back into (batch, positions, h * d_k) and project.
            merged = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
            return out_proj(merged)
    

    4 PositionwiseFeedForward

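    逐位置前馈网络:对每个位置独立地应用同一个两层全连接网络(中间为ReLU,代码实现中在ReLU之后还加了dropout):

    \[FFN(x)=\max(0,xW_1+b_1)W_2+b_2 \]

    其中输入和输出的维度都是\(d_{model}=512\),内层维度是\(d_{ff}=2048\),即\(W_1\in R^{512\times 2048}\),\(W_2\in R^{2048\times 512}\)。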

    class PositionwiseFeedForward(nn.Module):
        """Position-wise feed-forward network: ``w_2(dropout(relu(w_1(x))))``.

        The linear layers act on the last dimension, so every position is
        transformed independently with shared weights.
        """

        def __init__(self, d_model, d_ff, dropout=0.1):
            super().__init__()
            self.w_1 = nn.Linear(d_model, d_ff)
            self.w_2 = nn.Linear(d_ff, d_model)
            self.dropout = nn.Dropout(dropout)

        def forward(self, x):
            hidden = F.relu(self.w_1(x))
            hidden = self.dropout(hidden)
            return self.w_2(hidden)
    

    5 PositionalEncoding

    5.1 位置编码组件

    因为transformer既没有RNN也没有CNN,为了让模型利用序列的顺序信息,就需要把位置的相对或绝对信息注入到模型中,所以要在embedding上加上"位置编码"。位置编码的维度和embedding一样,都是\(d_{model}\),因此两者可以直接相加。

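    本文采用不同频率的正弦和余弦函数:

    \[PE_{(pos,2i)}=\sin(pos/10000^{2i/d_{model}}),\quad PE_{(pos,2i+1)}=\cos(pos/10000^{2i/d_{model}}) \]

    其中\(pos\)是位置,\(i\)是维度,即位置编码的每个维度都对应一条正弦曲线,波长从\(2\pi\)到\(10000\cdot 2\pi\)呈几何级数增长。选择这个函数,是因为假设它能让模型很容易学到相对位置关系:对任意固定偏移量\(k\),\(PE_{pos+k}\)都可以表示成\(PE_{pos}\)的线性函数。此外,对embedding与位置编码之和还会使用dropout,base模型中\(P_{drop}=0.1\)。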

    class PositionalEncoding(nn.Module):
        """Add fixed sinusoidal positional encodings, then apply dropout.

        PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

        The table is precomputed for ``max_len`` positions and stored as the
        buffer ``pe`` of shape ``(1, max_len, d_model)``. Odd ``d_model`` is
        supported (it has one more sine column than cosine columns).
        """

        def __init__(self, d_model, dropout, max_len=5000):
            super().__init__()
            self.dropout = nn.Dropout(p=dropout)

            # Compute the positional encodings once in log space. The
            # aranges are explicitly float so the products never run in
            # integer arithmetic (PyTorch versions without type promotion
            # truncated int * float).
            pe = torch.zeros(max_len, d_model)
            position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) *
                                 -(math.log(10000.0) / d_model))
            pe[:, 0::2] = torch.sin(position * div_term)
            # For odd d_model there are d_model // 2 cosine columns but one
            # more frequency, so trim div_term to match.
            pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
            self.register_buffer('pe', pe.unsqueeze(0))

        def forward(self, x):
            """Add the encodings of the first ``x.size(1)`` positions to ``x``.

            Assumes ``x`` is (batch, seq_len, d_model) token embeddings with
            ``seq_len <= max_len``.
            """
            # ``pe`` is a buffer, not a parameter, so no gradient flows into
            # it and no Variable wrapper is needed.
            x = x + self.pe[:, :x.size(1)]
            return self.dropout(x)
    

    5.2 例子

    # Plot dimensions 4-7 of a 20-dim positional encoding (no dropout) over
    # 100 positions; each dimension is a sinusoid of a different frequency.
    plt.figure(figsize=(15, 5))
    pe = PositionalEncoding(20, 0)
    y = pe.forward(Variable(torch.zeros(1, 100, 20)))
    plt.plot(np.arange(100), y[0, :, 4:8].data.numpy())
    plt.legend(["dim %d"%p for p in [4,5,6,7]])
    # Bare expression, presumably to suppress notebook cell output.
    None
    

    6 LabelSmoothing

    6.1 组件

    标签平滑会损害困惑度(perplexity),因为模型会学得更加"不确定",但能提高准确率和BLEU分数。
    实现上,它把真实标签(one-hot)平滑成一个分布(真实类别得到1-smoothing,其余非padding类别平分smoothing),再与模型预测的对数概率计算KL散度作为loss。

    class LabelSmoothing(nn.Module):
        """Label-smoothed KL-divergence loss.

        Builds a smoothed target distribution from the integer targets:
        ``confidence = 1 - smoothing`` on the true class, ``smoothing /
        (size - 2)`` on every other non-padding class, and 0 on the padding
        class. It returns the summed ``KLDivLoss`` between the model's
        log-probabilities and that distribution. Rows whose target is the
        padding index are zeroed entirely, so they contribute no loss.
        """

        def __init__(self, size, padding_idx, smoothing=0.0):
            super().__init__()
            # reduction='sum' is the non-deprecated spelling of
            # size_average=False.
            self.criterion = nn.KLDivLoss(reduction='sum')
            self.padding_idx = padding_idx
            self.confidence = 1.0 - smoothing  # mass on the true class
            self.smoothing = smoothing         # mass spread over other classes
            self.size = size                   # number of classes per sample
            # Last smoothed target distribution, kept for inspection/plotting.
            self.true_dist = None

        def forward(self, x, target):
            """Return the summed KL loss.

            Args:
                x: predicted log-probabilities of shape (N, size).
                target: true class indices of shape (N,).
            """
            assert x.size(1) == self.size
            # Classes that share the smoothing mass: all except the true class
            # and the padding class. With size <= 2 there are none, so no mass
            # is spread (previously a ZeroDivisionError).
            n_other = self.size - 2
            fill = self.smoothing / n_other if n_other > 0 else 0.0
            with torch.no_grad():
                true_dist = torch.full_like(x, fill)
                true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
                # Padding is never a valid prediction target.
                true_dist[:, self.padding_idx] = 0
                # Zero every row whose target is padding; masked_fill_ handles
                # zero, one or many padding rows uniformly.
                true_dist.masked_fill_(
                    (target == self.padding_idx).unsqueeze(1), 0.0)
            self.true_dist = true_dist
            return self.criterion(x, true_dist)
    

    6.2 例子

    # Example of label smoothing.
    # 5 classes, class 0 is the padding index, smoothing coefficient 0.4.
    crit = LabelSmoothing(5, 0, 0.4)
    
    # Pretend model predictions (probabilities; the log is taken below).
    predict = torch.FloatTensor([[0, 0.2, 0.7, 0.1, 0],
                                 [0, 0.2, 0.7, 0.1, 0], 
                                 [0, 0.2, 0.7, 0.1, 0]])
    
    # Compute the label-smoothed loss. The targets are classes 2, 1 and 0;
    # the last target is the padding index, so that row of the smoothed
    # distribution is zeroed.
    v = crit(Variable(predict.log()), 
             Variable(torch.LongTensor([2, 1, 0])))
    
    # Show the target distributions expected by the system.
    plt.imshow(crit.true_dist)
    

    1)@1如下

    2)@2如下

    3)@3如下

    4)@4如下

    5)@5如下

    6)@6如下

    7)@7如下

    crit = LabelSmoothing(5, 0, 0.1)
    def loss(x):
        """Return the smoothed loss for a single prediction.

        The prediction gives weight ``x`` to the true class 1 and weight 1 to
        each of classes 2-4, normalised by ``d = x + 3``. Class 0 is the
        padding index and gets probability 0.
        """
        d = x + 3 * 1
        predict = torch.FloatTensor([[0, x / d, 1 / d, 1 / d, 1 / d],
                                     ])
        # NOTE(review): log(0) = -inf in the padding column; whether the KL
        # term 0 * -inf evaluates to 0 or NaN depends on the PyTorch version's
        # kl_div implementation. Confirm the plotted values are finite.
        # ``crit`` returns a 0-dim tensor, which cannot be indexed with [0]
        # on PyTorch >= 0.4; ``.item()`` extracts the Python float.
        return crit(Variable(predict.log()),
                     Variable(torch.LongTensor([1]))).item()
    # Loss as a function of the model's confidence in the true class.
    plt.plot(np.arange(1, 100), [loss(x) for x in range(1, 100)])
    

    7 优化器

    7.1 优化器组件

    class NoamOpt:
        """Optimizer wrapper implementing the "Noam" learning-rate schedule.

        lrate = factor * model_size^-0.5 * min(step^-0.5, step * warmup^-1.5)

        This is a linear warm-up for ``warmup`` steps followed by
        inverse-square-root decay. ``step()`` writes the rate into every
        param group of the wrapped optimizer and then steps it.
        """

        def __init__(self, model_size, factor, warmup, optimizer):
            self.optimizer = optimizer
            self._step = 0
            self.warmup = warmup
            self.factor = factor
            self.model_size = model_size
            self._rate = 0  # last rate written to the optimizer

        def step(self):
            "Update parameters and rate"
            self._step += 1
            rate = self.rate()
            for p in self.optimizer.param_groups:
                p['lr'] = rate
            self._rate = rate
            self.optimizer.step()

        def rate(self, step=None):
            """Return the learning rate for ``step`` (default: current step).

            Steps <= 0 (e.g. before the first update) return 0.0. The formula
            would otherwise raise ZeroDivisionError (``0 ** -0.5``) or yield a
            complex number, and 0 is the limit of the linear warm-up as the
            step approaches 0.
            """
            if step is None:
                step = self._step
            if step <= 0:
                return 0.0
            return self.factor * \
                (self.model_size ** (-0.5) *
                min(step ** (-0.5), step * self.warmup ** (-1.5)))
            
    def get_std_opt(model):
        """Build the standard optimizer for ``model``.

        Adam (lr placeholder 0, betas=(0.9, 0.98), eps=1e-9) is wrapped in a
        ``NoamOpt`` with factor 2 and 4000 warm-up steps, sized by the source
        embedding's ``d_model``.
        """
        d_model = model.src_embed[0].d_model
        adam = torch.optim.Adam(model.parameters(), lr=0,
                                betas=(0.9, 0.98), eps=1e-9)
        return NoamOpt(d_model, 2, 4000, adam)
    

    7.2 例子

    # Three settings of the lrate hyperparameters: (model_size, warmup) =
    # (512, 4000), (512, 8000), (256, 4000). The optimizer is None because
    # only rate() is called here.
    opts = [NoamOpt(512, 1, 4000, None), 
            NoamOpt(512, 1, 8000, None),
            NoamOpt(256, 1, 4000, None)]
    plt.plot(np.arange(1, 20000), [[opt.rate(i) for opt in opts] for i in range(1, 20000)])
    plt.legend(["512:4000", "512:8000", "256:4000"])
    # Bare expression, presumably to suppress notebook cell output.
    None
    

    8 训练

    8.1 Batch

    class Batch:
        """Container for one batch of source/target token ids plus masks.

        Attributes:
            src: source ids.
            src_mask: ``src != pad`` with a singleton axis inserted at -2.
            trg, trg_y: the target shifted for teacher forcing. ``trg`` drops
                the last token (decoder input) and ``trg_y`` drops the first
                (expected output). Set only when ``trg`` is given.
            trg_mask: padding mask AND causal mask for ``trg``.
            ntokens: number of non-padding tokens in ``trg_y``.
        """

        def __init__(self, src, trg=None, pad=0):
            self.src = src
            self.src_mask = (src != pad).unsqueeze(-2)
            if trg is None:
                return
            self.trg = trg[:, :-1]
            self.trg_y = trg[:, 1:]
            self.trg_mask = self.make_std_mask(self.trg, pad)
            self.ntokens = (self.trg_y != pad).data.sum()

        @staticmethod
        def make_std_mask(tgt, pad):
            "Create a mask to hide padding and future words."
            not_pad = (tgt != pad).unsqueeze(-2)
            causal = Variable(subsequent_mask(tgt.size(-1)).type_as(not_pad.data))
            return not_pad & causal
    

    8.2 run_epoch

    def run_epoch(data_iter, model, loss_compute):
        """Standard Training and Logging Function.

        Runs ``model`` on every batch from ``data_iter`` and passes the output
        to ``loss_compute``, which may also update the model. Progress is
        printed at steps 1, 51, 101, ...

        Args:
            data_iter: iterable of batches exposing ``src``, ``trg``,
                ``src_mask``, ``trg_mask``, ``trg_y`` and ``ntokens``.
            model: called as ``model(src, trg, src_mask, trg_mask)``.
            loss_compute: called as ``loss_compute(out, trg_y, ntokens)``;
                returns the un-normalised loss of the batch.

        Returns:
            Average loss per token over the epoch, or NaN when no tokens were
            seen (e.g. an empty iterator).
        """
        start = time.perf_counter()
        total_tokens = 0
        total_loss = 0
        tokens = 0
        for i, batch in enumerate(data_iter):
            # Call the module itself (not .forward) so registered hooks run.
            out = model(batch.src, batch.trg,
                        batch.src_mask, batch.trg_mask)
            loss = loss_compute(out, batch.trg_y, batch.ntokens)
            total_loss += loss
            total_tokens += batch.ntokens
            tokens += batch.ntokens
            if i % 50 == 1:
                # Guard against a zero interval on coarse clocks.
                elapsed = max(time.perf_counter() - start, 1e-9)
                print("Epoch Step: %d Loss: %f Tokens per Sec: %f" %
                        (i, loss / batch.ntokens, tokens / elapsed))
                start = time.perf_counter()
                tokens = 0
        if total_tokens == 0:
            return float("nan")
        return total_loss / total_tokens
    

    8.3 batch_size_fn

    global max_src_in_batch, max_tgt_in_batch
    def batch_size_fn(new, count, sofar):
        """Return the padded token count of the batch after adding ``new``.

        ``count`` is the number of examples in the batch including ``new``;
        ``sofar`` is unused. Running maxima of the source length and of the
        target length + 2 (presumably for the BOS/EOS tokens the target
        field adds) are kept in module globals and reset when ``count == 1``.
        The result is the larger of ``count * max_src`` and
        ``count * max_tgt``.
        """
        global max_src_in_batch, max_tgt_in_batch
        if count == 1:
            max_src_in_batch, max_tgt_in_batch = 0, 0
        max_src_in_batch = max(max_src_in_batch, len(new.src))
        max_tgt_in_batch = max(max_tgt_in_batch, len(new.trg) + 2)
        return max(count * max_src_in_batch, count * max_tgt_in_batch)
    

    9 真实世界的例子

    9.1 数据装载

    # For data loading.
    from torchtext import data, datasets
    
    if True:
        # Tokenisers: spaCy pipelines loaded by the legacy shortcut names
        # 'de' / 'en'. NOTE(review): spaCy v3 removed shortcut names (use e.g.
        # 'de_core_news_sm' / 'en_core_web_sm'); confirm against the installed
        # spaCy version.
        import spacy
        spacy_de = spacy.load('de')
        spacy_en = spacy.load('en')
    
        def tokenize_de(text):
            """Split German text into token strings with spaCy's tokenizer."""
            return [tok.text for tok in spacy_de.tokenizer(text)]
    
        def tokenize_en(text):
            """Split English text into token strings with spaCy's tokenizer."""
            return [tok.text for tok in spacy_en.tokenizer(text)]
    
        # Special tokens: sentence start/end (target side only) and padding.
        BOS_WORD = '<s>'
        EOS_WORD = '</s>'
        BLANK_WORD = "<blank>"
        # NOTE(review): data.Field / datasets.IWSLT belong to the legacy
        # torchtext API; confirm the pinned torchtext version still has them.
        SRC = data.Field(tokenize=tokenize_de, pad_token=BLANK_WORD)
        TGT = data.Field(tokenize=tokenize_en, init_token = BOS_WORD, 
                         eos_token = EOS_WORD, pad_token=BLANK_WORD)
    
        # German -> English IWSLT; keep only pairs where both sides have at
        # most MAX_LEN tokens.
        MAX_LEN = 100
        train, val, test = datasets.IWSLT.splits(
            exts=('.de', '.en'), fields=(SRC, TGT), 
            filter_pred=lambda x: len(vars(x)['src']) <= MAX_LEN and 
                len(vars(x)['trg']) <= MAX_LEN)
        # Build vocabularies from the training split, keeping tokens that
        # occur at least MIN_FREQ times.
        MIN_FREQ = 2
        SRC.build_vocab(train.src, min_freq=MIN_FREQ)
        TGT.build_vocab(train.trg, min_freq=MIN_FREQ)
    

    9.2 迭代器

    class MyIterator(data.Iterator):
        """torchtext iterator that groups examples of similar length.

        Training: examples are read in pools of ``100 * batch_size``, each
        pool is sorted by ``sort_key`` and cut into batches, and those batches
        are shuffled. Evaluation: batches are formed in order and the
        examples inside each batch are sorted by ``sort_key``.
        """

        def create_batches(self):
            if not self.train:
                self.batches = [
                    sorted(chunk, key=self.sort_key)
                    for chunk in data.batch(self.data(), self.batch_size,
                                            self.batch_size_fn)
                ]
                return

            def pool(d, random_shuffler):
                for big_chunk in data.batch(d, self.batch_size * 100):
                    minibatches = data.batch(
                        sorted(big_chunk, key=self.sort_key),
                        self.batch_size, self.batch_size_fn)
                    yield from random_shuffler(list(minibatches))

            self.batches = pool(self.data(), self.random_shuffler)
    
    def rebatch(pad_idx, batch):
        """Convert a torchtext batch into a ``Batch``.

        The torchtext tensors are presumably (seq_len, batch); ``Batch``
        expects (batch, seq_len), so the first two axes of ``src`` and
        ``trg`` are swapped.
        """
        src = batch.src.transpose(0, 1)
        trg = batch.trg.transpose(0, 1)
        return Batch(src, trg, pad_idx)
    

    9.3 多gpu计算

    # Skip if not interested in multigpu.
    class MultiGPULossCompute:
        "A multi-gpu loss compute and train function."
        def __init__(self, generator, criterion, devices, opt=None, chunk_size=5):
            # Send out to different gpus. The criterion is replicated once
            # here; the generator is re-replicated on every call (presumably
            # so each call uses its current weights).
            self.generator = generator
            self.criterion = nn.parallel.replicate(criterion,
                                                   devices=devices)
            self.opt = opt
            self.devices = devices
            self.chunk_size = chunk_size

        def __call__(self, out, targets, normalize):
            """Run generator + criterion in chunks spread over ``devices``.

            The decoder output and targets are scattered across devices
            (dim 0) and processed in slices of ``chunk_size`` along dim 1.
            When training (``opt`` given), the gradient of each slice is
            collected, and the concatenated gradients are back-propagated
            through the transformer with a single ``out.backward`` call
            before the optimizer steps.

            Returns:
                The accumulated (normalised) loss multiplied by ``normalize``.
            """
            total = 0.0
            generator = nn.parallel.replicate(self.generator,
                                              devices=self.devices)
            out_scatter = nn.parallel.scatter(out,
                                              target_gpus=self.devices)
            out_grad = [[] for _ in out_scatter]
            targets = nn.parallel.scatter(targets,
                                          target_gpus=self.devices)

            # Divide generating into chunks.
            chunk_size = self.chunk_size
            for i in range(0, out_scatter[0].size(1), chunk_size):
                # Predict distributions. Each slice becomes a new leaf so its
                # gradient can be read back after the loss backward pass.
                out_column = [[Variable(o[:, i:i+chunk_size].data,
                                        requires_grad=self.opt is not None)]
                               for o in out_scatter]
                gen = nn.parallel.parallel_apply(generator, out_column)

                # Compute loss on flattened (tokens, vocab) inputs per device.
                y = [(g.contiguous().view(-1, g.size(-1)),
                      t[:, i:i+chunk_size].contiguous().view(-1))
                     for g, t in zip(gen, targets)]
                loss = nn.parallel.parallel_apply(self.criterion, y)

                # Sum and normalize loss. The per-device losses are 0-dim, so
                # ``.sum()[0]`` / ``.data[0]`` (which raise on 0-dim tensors
                # in PyTorch >= 0.4) are replaced by ``.sum()`` / ``.item()``.
                l = nn.parallel.gather(loss,
                                       target_device=self.devices[0])
                l = l.sum() / normalize
                total += l.item()

                # Backprop loss to output of transformer
                if self.opt is not None:
                    l.backward()
                    for j in range(len(loss)):
                        out_grad[j].append(out_column[j][0].grad.data.clone())

            # Backprop all loss through transformer.
            if self.opt is not None:
                out_grad = [Variable(torch.cat(og, dim=1)) for og in out_grad]
                o1 = out
                o2 = nn.parallel.gather(out_grad,
                                        target_device=self.devices[0])
                o1.backward(gradient=o2)
                self.opt.step()
                self.opt.optimizer.zero_grad()
            return total * normalize
    
    # GPUs to use
    devices = [0, 1, 2, 3]
    if True:
        # Index of the padding token in the target vocabulary.
        pad_idx = TGT.vocab.stoi["<blank>"]
        # Full-size model (N=6) moved to the GPU.
        model = make_model(len(SRC.vocab), len(TGT.vocab), N=6)
        model.cuda()
        # Label smoothing 0.1; rows whose target is padding contribute no loss.
        criterion = LabelSmoothing(size=len(TGT.vocab), padding_idx=pad_idx, smoothing=0.1)
        criterion.cuda()
        # Batch size counts padded tokens (via batch_size_fn), not sentences.
        BATCH_SIZE = 12000
        train_iter = MyIterator(train, batch_size=BATCH_SIZE, device=0,
                                repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
                                batch_size_fn=batch_size_fn, train=True)
        valid_iter = MyIterator(val, batch_size=BATCH_SIZE, device=0,
                                repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
                                batch_size_fn=batch_size_fn, train=False)
        # Data-parallel wrapper used for the forward pass across `devices`.
        model_par = nn.DataParallel(model, device_ids=devices)
    # Bare expression, presumably to suppress notebook cell output.
    None
    

    9.4 训练

    # Flip to True to train from scratch; otherwise a saved model is loaded.
    if False:
        # NoamOpt around Adam with factor 1 and 2000 warm-up steps.
        model_opt = NoamOpt(model.src_embed[0].d_model, 1, 2000,
                torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
        for epoch in range(10):
            model_par.train()
            # Training pass: the loss helper receives the optimizer.
            run_epoch((rebatch(pad_idx, b) for b in train_iter), 
                      model_par, 
                      MultiGPULossCompute(model.generator, criterion, 
                                          devices=devices, opt=model_opt))
            model_par.eval()
            # Validation pass (no optimizer); prints the per-token loss.
            loss = run_epoch((rebatch(pad_idx, b) for b in valid_iter), 
                              model_par, 
                              MultiGPULossCompute(model.generator, criterion, 
                              devices=devices, opt=None))
            print(loss)
    else:
        # NOTE(review): torch.load unpickles the file; only load trusted
        # files. Recent PyTorch versions default to weights_only=True, which
        # may reject a full pickled module; confirm how iwslt.pt was saved.
        model = torch.load("iwslt.pt")
    
    # Translate the first sentence of the first validation batch and print it
    # next to the reference translation (the loop breaks after one batch).
    for i, batch in enumerate(valid_iter):
        # Swap to (batch, seq_len) and keep only the first sentence.
        src = batch.src.transpose(0, 1)[:1]
        src_mask = (src != SRC.vocab.stoi["<blank>"]).unsqueeze(-2)
        out = greedy_decode(model, src, src_mask, 
                            max_len=60, start_symbol=TGT.vocab.stoi["<s>"])
        print("Translation:", end="\t")
        # Skip position 0 (the <s> start symbol) and stop at </s>.
        # Note: this inner loop reuses (shadows) the outer loop variable i.
        for i in range(1, out.size(1)):
            sym = TGT.vocab.itos[out[0, i]]
            if sym == "</s>": break
            print(sym, end =" ")
        print()
        print("Target:", end="\t")
        # Reference: batch.trg is indexed [position, sentence]; read column 0.
        for i in range(1, batch.trg.size(0)):
            sym = TGT.vocab.itos[batch.trg.data[i, 0]]
            if sym == "</s>": break
            print(sym, end =" ")
        print()
        break
    

    9.5 额外组件 BPE, Search, Averaging

  • 相关阅读:
    「SELECT~FOR UPDATE NOWAIT」
    IT精英完美的七种生活方式
    ASP.NET下载CSV文件
    对一个Frame内控件的遍历
    .Net日期与时间的取得方法
    表的字段修改(SQL语句)
    谁能给我一些软件开发相关的名言警句
    LeetCode: Add two numbers
    LeetCode: 3Sum
    LeetCode: 4Sum
  • 原文地址:https://www.cnblogs.com/shouhuxianjian/p/16165451.html
Copyright © 2020-2023  润新知