• PyTorch: an introduction to Dataset / DataLoader / collate_fn / pad_sequence


    Reference for Dataset and DataLoader: https://www.cnblogs.com/dogecheng/p/11930535.html

    Reference for pad_sequence: https://zhuanlan.zhihu.com/p/59772104

    Why use these classes?

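    DataLoader wraps a Dataset object and iterates over it batch by batch. With num_workers > 0 it loads data in parallel worker processes (not threads), which can speed up training.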
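    Here is a minimal, torch-free sketch of roughly what a DataLoader does when num_workers=0 and drop_last=False. It takes indices in order, or shuffled like RandomSampler. It fetches each sample through __getitem__ and passes each group to collate_fn. The name simple_loader is hypothetical. The real DataLoader also handles worker processes, pinned memory, and a default collate function that stacks tensors.

```python
import random

# Hypothetical, torch-free sketch of a DataLoader with num_workers=0:
# draw indices, group them into batches, fetch samples via dataset[i],
# and merge each batch with collate_fn. The last batch may be smaller.
def simple_loader(dataset, batch_size, shuffle=False, collate_fn=list, seed=None):
    indices = list(range(len(dataset)))          # relies on __len__
    if shuffle:                                  # RandomSampler-like order
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch_idx = indices[start:start + batch_size]
        samples = [dataset[i] for i in batch_idx]    # relies on __getitem__
        yield collate_fn(samples)

for batch in simple_loader(["a", "b", "c", "d", "e"], batch_size=2):
    print(batch)
# ['a', 'b']
# ['c', 'd']
# ['e']
```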

    Dataset

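    A Dataset subclass needs to override two methods: __len__(), which returns the number of samples, and __getitem__(), which returns the sample at a given index.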

    For example:

    import torch
    from torch.utils.data import Dataset, DataLoader
    
    
    class MyDataset(Dataset):
        def __init__(self, path, is_train):
            self.is_train = is_train
            self.dataset = []
            with open(path, encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if line:  # skip blank lines
                        self.dataset.append(line)
    
        def __len__(self):
            # number of samples; used by samplers and by len(DataLoader)
            return len(self.dataset)
    
        def __getitem__(self, item):
            # Assumed line format: space-separated token ids with the label as the
            # last field, e.g. "11 13 10 24 34 12 0"; test lines hold only token ids.
            if self.is_train:
                *text, label = self.dataset[item].split()
                text, label = list(map(int, text)), [int(label)]
                return {'text': torch.tensor(text), 'label': torch.tensor(label)}
            else:
                text = list(map(int, self.dataset[item].split()))
                return {'text': torch.tensor(text)}
    
    train_data = MyDataset('train.txt', is_train=True)
    print(train_data[0])
    print(train_data[1])
    # {'text': tensor([11, 13, 10, 24, 34, 12]), 'label': tensor([0])}
    # {'text': tensor([12, 15, 21]), 'label': tensor([1])}
    

    DataLoader / collate_fn

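    Commonly used DataLoader parameters:

    • batch_size, shuffle: the number of samples per batch, and whether to reshuffle the data at every epoch
    • num_workers: the number of worker processes used for loading (0 loads in the main process). On Windows, set it to 0 to avoid frequent errors, or put the training code under if __name__ == '__main__':
    • sampler: the strategy for drawing sample indices. If it is omitted, DataLoader uses RandomSampler when shuffle=True and SequentialSampler otherwise
    • collate_fn: a custom function that merges a list of samples into one batch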
    import torch
    from torch.nn.utils.rnn import pad_sequence
    
    def collate_train(batch_data):
        # batch_data: a list of batch_size samples returned by MyDataset.__getitem__
        # The padded batch later goes through pack_padded_sequence before the RNN.
        # By default (enforce_sorted=True) that function needs the sequences sorted
        # longest first, and it always needs the length of each sequence.
        batch_data.sort(key=lambda s: len(s['text']), reverse=True)
        text_batch = []
        label_batch = []
        text_len = []
        for data in batch_data:
            text_batch.append(data['text'])
            text_len.append(len(data['text']))
            label_batch.append(data['label'])
        # pad_sequence needs a list of tensors; it pads with 0 up to the longest one
        text_batch = pad_sequence(text_batch, batch_first=True)
        label_batch = torch.cat(label_batch)  # each label has shape [1] -> [batch]
        return {'text': text_batch, 'label': label_batch, 'text_len': text_len}
    
    train_data = MyDataset('train.txt', is_train=True)
    train_loader = DataLoader(train_data, batch_size=2, num_workers=0, shuffle=True, collate_fn=collate_train)
    for batch in train_loader:
        # batch is a dict: index it by key (tuple-unpacking a dict yields its keys)
        seq, label, seq_len = batch['text'], batch['label'], batch['text_len']
        print(seq)
    # tensor([[11, 13, 10, 24, 34, 12],
    #         [12, 15, 21,  0,  0,  0]])
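    The logic of collate_train can be traced on plain Python lists. The sketch below is hypothetical: collate_lists is not a PyTorch function, and it pads by hand instead of calling pad_sequence. It uses sorted() rather than list.sort(), so the caller's list is left unchanged.

```python
# Torch-free sketch of collate_train on plain lists (hypothetical helper).
# Each sample is assumed to look like {'text': [token ids], 'label': int}.
def collate_lists(batch_data, padding_value=0):
    # longest first, as pack_padded_sequence(enforce_sorted=True) expects
    batch_data = sorted(batch_data, key=lambda s: len(s['text']), reverse=True)
    text_len = [len(s['text']) for s in batch_data]
    max_len = text_len[0] if text_len else 0
    # right-pad every sequence to the longest length in this batch
    text = [s['text'] + [padding_value] * (max_len - len(s['text'])) for s in batch_data]
    label = [s['label'] for s in batch_data]
    return {'text': text, 'label': label, 'text_len': text_len}

batch = [{'text': [12, 15, 21], 'label': 1},
         {'text': [11, 13, 10, 24, 34, 12], 'label': 0}]
print(collate_lists(batch))
# {'text': [[11, 13, 10, 24, 34, 12], [12, 15, 21, 0, 0, 0]], 'label': [0, 1], 'text_len': [6, 3]}
```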
    

    SubsetRandomSampler

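    To hold out a validation set, use SubsetRandomSampler:

    • When a sampler is given, shuffle must stay False (its default). DataLoader raises a ValueError if both are set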
    import random
    from torch.utils.data.sampler import SubsetRandomSampler
    
    train_data = MyDataset('train.txt', is_train=True)
    n_train = len(train_data)
    split = n_train // 3                 # hold out one third for validation
    indices = list(range(n_train))
    random.shuffle(indices)              # avoid a split biased by the file order
    train_sampler = SubsetRandomSampler(indices[split:])
    valid_sampler = SubsetRandomSampler(indices[:split])
    train_loader = DataLoader(train_data, sampler=train_sampler, shuffle=False, batch_size=2, collate_fn=collate_train)
    valid_loader = DataLoader(train_data, sampler=valid_sampler, shuffle=False, batch_size=2, collate_fn=collate_train)
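    Conceptually, the split produces two disjoint index subsets. SubsetRandomSampler then yields its subset in a fresh random order on every pass. Below is a torch-free sketch; split_indices and subset_random_order are hypothetical helper names.

```python
import random

# Hypothetical sketch of the train/validation split and of the per-epoch
# behavior of SubsetRandomSampler (a random permutation of its subset).
def split_indices(n, seed=0):
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    split = n // 3                            # one third for validation
    return indices[split:], indices[:split]   # (train, valid)

def subset_random_order(subset, rng):
    order = list(subset)
    rng.shuffle(order)                        # new order on each call
    return order

train_idx, valid_idx = split_indices(9)
print(len(train_idx), len(valid_idx))         # 6 3
rng = random.Random(1)
print(sorted(subset_random_order(valid_idx, rng)) == sorted(valid_idx))  # True
```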
    

    pad_sequence/pack_padded_sequence/pad_packed_sequence

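    Note: sentences in an NLP batch usually differ in length, so they are typically truncated or padded. Feeding such a batch to an RNN involves three functions.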
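    pad_sequence only pads to the longest sequence in the batch and never truncates. When a fixed maximum length is wanted, truncation is usually done by hand before padding. Here is a minimal sketch; truncate_or_pad and max_len are hypothetical names, not PyTorch APIs.

```python
# Hypothetical helper: force a sentence to exactly max_len tokens by
# cutting it if it is too long and right-padding it if it is too short.
def truncate_or_pad(seq, max_len, padding_value=0):
    seq = list(seq[:max_len])                            # truncate
    return seq + [padding_value] * (max_len - len(seq))  # pad

print(truncate_or_pad([1, 2, 3, 4, 5, 6], 4))  # [1, 2, 3, 4]
print(truncate_or_pad([1, 2, 3], 4))           # [1, 2, 3, 0]
```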

    Step 1: Padding

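    • pad_sequence(): pads every sequence to the length of the longest one. The padding value is 0 by default and can be changed with padding_value

    • pad_sequence() is already called in collate_train. A padded batch looks like this:

    batch_seqs = torch.tensor([[1, 2, 3, 4, 5, 6],
                               [1, 2, 3, 0, 0, 0]])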
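    For reference, the padding rule can be written out on plain lists. pad_lists is a hypothetical, torch-free stand-in that mirrors pad_sequence's batch_first and padding_value behavior. With batch_first=False the result is time-major, with shape (max_len, batch).

```python
# Hypothetical torch-free mirror of pad_sequence on nested lists.
def pad_lists(sequences, batch_first=False, padding_value=0):
    max_len = max(len(s) for s in sequences)
    padded = [list(s) + [padding_value] * (max_len - len(s)) for s in sequences]
    if batch_first:
        return padded                            # shape: (batch, max_len)
    return [list(col) for col in zip(*padded)]   # shape: (max_len, batch)

print(pad_lists([[1, 2, 3, 4, 5, 6], [1, 2, 3]], batch_first=True))
# [[1, 2, 3, 4, 5, 6], [1, 2, 3, 0, 0, 0]]
print(pad_lists([[1, 2], [3]]))
# [[1, 3], [2, 0]]
```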
    

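    Step 2: Packing

    • Running the RNN over the padding zeros wastes computation. For shorter sequences, the final hidden state would also be computed after the padding steps. pack_padded_sequence avoids both problems
    • By default (enforce_sorted=True), pack_padded_sequence expects the sequences in the batch to be sorted from longest to shortest. It always needs the length of each sequence. Passing enforce_sorted=False lets it sort internally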
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
    
    batch_lens = [6, 3]
    x = pack_padded_sequence(batch_seqs, lengths=batch_lens, batch_first=True)
    print(x)
    # PackedSequence(data=tensor([1, 1, 2, 2, 3, 3, 4, 5, 6]),
    #                batch_sizes=tensor([2, 2, 2, 1, 1, 1]), sorted_indices=None,
    #                unsorted_indices=None)
    # The RNN then steps through time, processing batch_sizes[t] sequences at step t
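    This is where data and batch_sizes in the output above come from. The packed data is read one time step at a time across the batch, and batch_sizes[t] is the number of sequences longer than t. The torch-free sketch below assumes batch_first input that is already sorted longest first. pack_lists is a hypothetical name.

```python
# Hypothetical torch-free sketch of the PackedSequence layout.
def pack_lists(padded, lengths):
    if list(lengths) != sorted(lengths, reverse=True):
        raise ValueError("sequences must be sorted longest first")
    # batch_sizes[t] = number of sequences still active at time step t
    batch_sizes = [sum(1 for n in lengths if n > t) for t in range(lengths[0])]
    data = []
    for t, bs in enumerate(batch_sizes):
        # sorted input: the active sequences are exactly the first bs rows
        data.extend(padded[b][t] for b in range(bs))
    return data, batch_sizes

data, batch_sizes = pack_lists([[1, 2, 3, 4, 5, 6], [1, 2, 3, 0, 0, 0]], [6, 3])
print(data)         # [1, 1, 2, 2, 3, 3, 4, 5, 6]
print(batch_sizes)  # [2, 2, 2, 1, 1, 1]
```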
    

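    Step 3: Unpacking

    • pad_packed_sequence converts a PackedSequence (typically the RNN's packed output) back into a padded tensor plus the sequence lengths: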
    x = pack_padded_sequence(batch_seqs, lengths=batch_lens, batch_first=True)
    x = pad_packed_sequence(x, batch_first=True)
    print(x)
    # (tensor([[1, 2, 3, 4, 5, 6],
    #          [1, 2, 3, 0, 0, 0]]), tensor([6, 3]))
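    Unpacking reverses that layout. Below is a torch-free sketch; unpack_lists is a hypothetical name. It again assumes the sequences were sorted longest first, so at step t the first batch_sizes[t] sequences are still active.

```python
# Hypothetical torch-free sketch: rebuild a batch_first padded batch and
# the lengths from packed data and batch_sizes.
def unpack_lists(data, batch_sizes, padding_value=0):
    batch = batch_sizes[0]            # number of sequences
    max_len = len(batch_sizes)        # number of time steps
    padded = [[padding_value] * max_len for _ in range(batch)]
    lengths = [0] * batch
    pos = 0
    for t, bs in enumerate(batch_sizes):
        for b in range(bs):           # first bs sequences are active at step t
            padded[b][t] = data[pos]
            lengths[b] += 1
            pos += 1
    return padded, lengths

print(unpack_lists([1, 1, 2, 2, 3, 3, 4, 5, 6], [2, 2, 2, 1, 1, 1]))
# ([[1, 2, 3, 4, 5, 6], [1, 2, 3, 0, 0, 0]], [6, 3])
```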
    
  • Original article: https://www.cnblogs.com/mingriyingying/p/13387276.html