• pytorch中DataSet和DataLoader的使用详解


    1. 首先导入需要用到的包

    from torch.utils.data import DataLoader,Dataset
    

    2. 自定义Dataset

    一般情况下我们使用Dataset,需要自定义一个类来继承Dataset,然后实现__getitem__()方法和__len__()方法
    使用示例如下所示:

    import torch
    a = [[1,2,3,4],[4,5,6,7,9],[6,7,8,9,4,5],[4,3,2],[8,7,5,4],[4,8,7,1]]
    b = [1,2,3,4,5,6]
    
    class mydataset(Dataset):
        def __init__(self,x,y):
            self.feature = x
            self.label = y
        
        def __getitem__(self,item):
            return torch.tensor(self.feature[item]),self.label[item]   #根据需要进行设置
    
        def __len__(self):
            return len(self.feature)
    
    dataset = mydataset(a,b)
    
    print(dataset[0])
    

    程序运行结果如下所示:

    (tensor([1, 2, 3, 4]), 1)
    

    3. 创建DataLoader

    DataLoader需要传入几个参数,先看一下官方的定义:

    常用到的几个参数解释如下:

    # dataset:数据集,传入我们刚才创建的数据集即可;
    # batch_size:每个batch的大小
    # collate_fn:按照定义函数的方式进行取数据
    # shuffle:是否将数据集中的数据进行打乱
    

    使用示例如下所示:

    def fun(x):                                                    # 根据自己的需求定义dataloader返回数据的格式
        x.sort(key=lambda data:len(data[0]),reverse=True)
        # print(x)
        feature = []
        label = []
        length = []
        for i in x:
            feature.append(i[0])
            label.append(i[1])
            length.append(len(i[0]))
        # feature = pad_sequence(feature,batch_first=True,padding_value=-1)     # 可以适当的进行补齐操作
        return feature,label,length
    
    
    dataloader = DataLoader(dataset,batch_size=2,collate_fn=fun)    # 定义DataLoader
    
    for x,y,length in dataloader:
        print(x,y,length)
        print('------------------')
    

    程序运行结果如下所示:

    [tensor([4, 5, 6, 7, 9]), tensor([1, 2, 3, 4])] [2, 1] [5, 4]
    ------------------
    [tensor([6, 7, 8, 9, 4, 5]), tensor([4, 3, 2])] [3, 4] [6, 3]
    ------------------
    [tensor([8, 7, 5, 4]), tensor([4, 8, 7, 1])] [5, 6] [4, 4]
    
  • 相关阅读:
    RT-Thread代码启动过程与$Sub$ $main、$Super$ $main
    软件开源许可证
    git回退到历史版本以及再滚回去
    GMT、UTC、UNIX时间戳、时区
    sprintf的使用
    C# Json 和对象的相互转换
    获取指定年份/月份的周六周天 + 标记指定日期(加粗)
    Winform 窗体实现圆角展示
    VS2012统计代码量
    C# Winform 中使用FTP实现软件自动更新功能
  • 原文地址:https://www.cnblogs.com/noob-l/p/14674563.html
Copyright © 2020-2023  润新知