• [深度学习] pytorch利用Datasets和DataLoader读取数据


     

    本文简单描述如果自定义dataset,代码并未经过测试(只是说明思路),为半伪代码。所有逻辑需按自己需求另外实现:

     

    一、分析DataLoader

    train_loader = DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=batch_size, shuffle=True)

    datasets.MNIST()是一个torch.utils.data.Datasets对象,batch_size表示我们定义的batch大小(即每轮训练使用的批大小),shuffle表示是否打乱数据顺序(对于整个datasets里包含的所有数据)。

    对于batch_size和shuffle都是根据业务需求来认为指定的,不做过多说明。

    对于Datasets对象来说,我们可以根据自己的数据类型来自定义,自己定义一个类,继承Datasets类。

    二、分析Datasets类

    class Dataset(object):
        """An abstract class representing a Dataset.
    
        All other datasets should subclass it. All subclasses should override
        ``__len__``, that provides the size of the dataset, and ``__getitem__``,
        supporting integer indexing in range from 0 to len(self) exclusive.
        """
    
        def __getitem__(self, index):
            raise NotImplementedError
    
        def __len__(self):
            raise NotImplementedError
    
        def __add__(self, other):
            return ConcatDataset([self, other])

    上述代码是pytorch中Datasets的源码,注意成员方法__getitem__和__len__都是未实现的。我们要实现自定义Datasets类来完成数据的读取,则只需要完成这两个成员方法的重写。

      首先,__getitem__()方法用来从datasets中读取一条数据,这条数据包含训练图片(已CV距离)和标签,参数index表示图片和标签在总数据集中的Index。

      __len__()方法返回数据集的总长度(训练集的总数)。

    三、简单实现MyDatasets类

    # -*- coding:utf-8 -*-
    __author__ = 'Leo.Z'
    
    import os
    
    from torch.utils.data import Dataset
    from torch.utils.data import DataLoader
    import matplotlib.image as mpimg
    
    
    # 对所有图片生成path-label map.txt
    def generate_map(root_dir):
        current_path = os.path.abspath(__file__)
        father_path = os.path.abspath(os.path.dirname(current_path) + os.path.sep + ".")
    
        with open(root_dir + 'map.txt', 'w') as wfp:
            for idx in range(10):
                subdir = os.path.join(root_dir, '%d/' % idx)
                for file_name in os.listdir(subdir):
                    abs_name = os.path.join(father_path, subdir, file_name)
                    linux_abs_name = abs_name.replace("\", '/')
                    wfp.write('{file_dir} {label}
    '.format(file_dir=linux_abs_name, label=idx))
    
    
    # 实现MyDatasets类
    class MyDatasets(Dataset):
    
        def __init__(self, dir):
            # 获取数据存放的dir
            # 例如d:/images/
            self.data_dir = dir
            # 用于存放(image,label) tuple的list,存放的数据例如(d:/image/1.png,4)
            self.image_target_list = []
            # 从dir--label的map文件中将所有的tuple对读取到image_target_list中
            # map.txt中全部存放的是d:/.../image_data/1/3.jpg 1 路径最好是绝对路径
            with open(os.path.join(dir, 'map.txt'), 'r') as fp:
                content = fp.readlines()
                str_list = [s.rstrip().split() for s in content]
                # 将所有图片的dir--label对都放入列表,如果要执行多个epoch,可以在这里多复制几遍,然后统一shuffle比较好
                self.image_target_list = [(x[0], int(x[1])) for x in str_list]
    
        def __getitem__(self, index):
            image_label_pair = self.image_target_list[index]
            # 按path读取图片数据,并转换为图片格式例如[3,32,32]
            img = mpimg.imread(image_label_pair[0])
            return img, image_label_pair[1]
    
        def __len__(self):
            return len(self.image_target_list)
    
    
    if __name__ == '__main__':
        # 生成map.txt
        # generate_map('train/')
    
        train_loader = DataLoader(MyDatasets('train/'), batch_size=128, shuffle=True)
    
        for step in range(20000):
            for idx, (img, label) in enumerate(train_loader):
                print(img.shape)
                print(label.shape)

    上述代码简要说明了利用Datasets类和DataLoader类来读取数据,本例用的是图片原始数据,大概的结构如下:

    如果使用其他形式的数据,例如二进制文件,则需要字节读取文件,分割成每一张图片和label,然后从__getitem__中返回就可以了。例如cifar-10数据,我们只需要在__getitem__方法中,按index来读取对应位置的字节,然后转换为label和img,并返回。在__len__中返回cifar-10训练集的总样本数。DataLoader就可以根据我们提供的index,len以及batch_size,shuffle来返回相应的batch数据和label。

  • 相关阅读:
    后端——框架——视图层框架——spring_mvc——《官网》阅读笔记——第一章节26(过滤器,ShallowEtagHeaderFilter)
    后端——框架——视图层框架——spring_mvc——《官网》阅读笔记——第一章节27(过滤器,CorsFilter)
    后端——框架——视图层框架——spring_mvc——《官网》阅读笔记——第一章节28(过滤器,其他Filter)
    后端——框架——视图层框架——spring_mvc——《官网》阅读笔记——第一章节29(注解,Controller类注解)
    后端——框架——视图层框架——spring_mvc——《官网》阅读笔记——第一章节30(注解,Handler方法注解)
    任务日历关联(Project)
    新建日历(Project)
    例外日期(Project)
    自定义日历(Project)
    日历的种类(Project)
  • 原文地址:https://www.cnblogs.com/leokale-zz/p/11275800.html
Copyright © 2020-2023  润新知