• 04 Transformer 中的位置编码的 Pytorch 实现


    1:10 点赞

    16:00

    img

    我爱你

    你爱我

    1401

    img
    class PositionalEncoding(nn.Module):
    
        def __init__(self, dim, dropout, max_len=5000):
            super(PositionalEncoding, self).__init__()
    
            if dim % 2 != 0:
                raise ValueError("Cannot use sin/cos positional encoding with "
                                 "odd dim (got dim={:d})".format(dim))
    
            """
            构建位置编码pe
            pe公式为:
            PE(pos,2i/2i+1) = sin/cos(pos/10000^{2i/d_{model}})
            """
            pe = torch.zeros(max_len, dim)  # max_len 是解码器生成句子的最长的长度,假设是 10
            position = torch.arange(0, max_len).unsqueeze(1)
            div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) *
                                  -(math.log(10000.0) / dim)))
    
    
            pe[:, 0::2] = torch.sin(position.float() * div_term)
            pe[:, 1::2] = torch.cos(position.float() * div_term)
            pe = pe.unsqueeze(1)
            self.register_buffer('pe', pe)
            self.drop_out = nn.Dropout(p=dropout)
            self.dim = dim
    
        def forward(self, emb, step=None):
    
            emb = emb * math.sqrt(self.dim)
    
            if step is None:
                emb = emb + self.pe[:emb.size(0)]
            else:
                emb = emb + self.pe[step]
            emb = self.drop_out(emb)
            return emb
    
    
  • 相关阅读:
    用纯 javascript 提高博客访问量
    大龄程序员交流
    Git 本地仓库操作基本命令
    SoapUI登录测试(2)-- 断言
    SoapUI测试登录
    deleteMany is not a function
    jQuery contextMenu使用
    安装MongoDB -- Windows平台
    TortoiseGit 图标不显示
    C#的自定义滚动条
  • 原文地址:https://www.cnblogs.com/nickchen121/p/16529997.html
Copyright © 2020-2023  润新知