• [TORCH] pack_padded_sequence 和 pad_packed_sequence 的使用 (2020版本)


    内容简介

    本文主要是通过代码的方式展示pytorch的pack和pad函数。
    找到的两个可以参考的靠谱网站(不是CSDN的奇怪东西):
    理论链接,建议直接看图
    实践链接,直接看代码

    使用的代码

    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
    from torch.nn import utils as nn_utils
    import torch.nn.functional as F
    import torch
    
    # seq example
    # batch的尺寸是5,假设我们有五句话,每句话有不定长的词汇
    # 这里只假设每个词汇的feature是一维的
    batch_size = 5
    a = torch.tensor([1,2])
    b = torch.tensor([1,2,3])
    c = torch.tensor([1,2,3,4])
    d = torch.tensor([1])
    e = torch.tensor([1,2,3,4,5,6])
    
    
    # general setting
    # 提取五个句子的有效内容的长度
    # 并且提取最大句子的长度
    seq_lens = []
    for i in [a,b,c,d,e]:
        seq_lens.append(len(i))
    max_len = max(seq_lens)
    
    
    # Zero padding
    # 通过加入0pad,让他们的长度相等,这个长度是最长句子的长度
    a = F.pad(a,(0,max_len-len(a))) # 最低维度,前面增加0个,后面增加max-len(a)个
    b = F.pad(b,(0,max_len-len(b)))
    c = F.pad(c,(0,max_len-len(c)))
    d = F.pad(d,(0,max_len-len(d)))
    e = F.pad(e,(0,max_len-len(e)))
    print("在a句子经过pad填充以后:
    {}
    ".format(a))
    
    
    # merge the seq
    seq = torch.cat((a,b,c,d,e),0).view(-1,max_len) 
    print("所有句子融合以后可以获得整个矩阵:
    {}
    ".format(seq))
    
    
    # Pack
    # 1. input size 可以是(T×B×* ) = (最长序列长度T,batch size B,任意维度*)
    # 2. input size 可以是(B×T×*), 如果batch_first=True的话
    # 这里我们选择 batch 在前,所以是2
    packed_seq = pack_padded_sequence(seq, seq_lens, batch_first=True, enforce_sorted=False)
    print('经过了 pack_padded_sequence 处理:
    {}
    '.format(packed_seq))
    
    
    # Unpack
    unpacked_seq, unpacked_lens = pad_packed_sequence(packed_seq, batch_first=True)
    print('Unpack还原的结果:
    {}
    '.format(unpacked_seq))
    print('同时返回seq的length:
    {}
    '.format(unpacked_lens))
    

    代码运行的结果

    • 在a句子经过pad填充以后:
      tensor([1, 2, 0, 0, 0, 0])

    • 所有句子融合以后可以获得整个矩阵:
      tensor([[1, 2, 0, 0, 0, 0],
      [1, 2, 3, 0, 0, 0],
      [1, 2, 3, 4, 0, 0],
      [1, 0, 0, 0, 0, 0],
      [1, 2, 3, 4, 5, 6]])

    • 经过了 pack_padded_sequence 处理:
      PackedSequence(
      data=tensor([1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 5, 6]),
      batch_sizes=tensor([5, 4, 3, 2, 1, 1]),
      sorted_indices=tensor([4, 2, 1, 0, 3]),
      unsorted_indices=tensor([3, 2, 1, 4, 0]))

    • Unpack还原的结果:
      tensor([[1, 2, 0, 0, 0, 0],
      [1, 2, 3, 0, 0, 0],
      [1, 2, 3, 4, 0, 0],
      [1, 0, 0, 0, 0, 0],
      [1, 2, 3, 4, 5, 6]])

    • 同时返回seq的length:
      tensor([2, 3, 4, 1, 6])

  • 相关阅读:
    linux文件上传
    ios base64图片上传失败问题
    ERROR 1267 (HY000): Illegal mix of collations (utf8_general_ci,IMPLICIT) and (utf8_unicode_ci,IMPLICIT) for operation '='
    配置SQL Server 2012 AlwaysOn ——step3 配置数据库
    配置SQL Server 2012 AlwaysOn ——step2 建立群集
    配置SQL Server 2012 AlwaysOn ——step1 建立AD域及DNS配置
    适应多场景应用的web系统架构探讨
    住院病案首页数据填写质量规范
    病案首页规范
    vs2015离线使用nuget
  • 原文地址:https://www.cnblogs.com/kykai/p/14033580.html
Copyright © 2020-2023  润新知