• little_by_little_2 为一个数据集创建一个dataset类。(基于pytorch)


    little_by_little_2 为一个数据集创建一个dataset类。(基于pytorch)

    前言

    最近一段时间陷入了焦虑,迷茫之中最终获得了救赎。不想提及。

    任务

    为一个分类100元和1元的数据集创建一个pytorch.dataset,以便dataloader来读取

    源代码

    import os
    import random
    from PIL import Image
    from torch.utils.data import Dataset
    
    random.seed(1)
    rmb_label = {"1": 0, "100": 1}
    
    #1
    class RMBDataset(Dataset):
        def __init__(self, data_dir, transform=None):
            """
            rmb面额分类任务的Dataset
            :param data_dir: str, 数据集所在路径
            :param transform: torch.transform,数据预处理
            """
            self.label_name = {"1": 0, "100": 1}
            self.data_info = self.get_img_info(data_dir)  # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
            self.transform = transform
    #2
        def __getitem__(self, index):
            path_img, label = self.data_info[index]
            img = Image.open(path_img).convert('RGB')     # 0~255
    
            if self.transform is not None:
                img = self.transform(img)   # 在这里做transform,转为tensor等等
    
            return img, label
    
        def __len__(self):
            return len(self.data_info)
    #3
        @staticmethod
        def get_img_info(data_dir):
            data_info = list()
            for root, dirs, _ in os.walk(data_dir):
                # 遍历类别
                for sub_dir in dirs:
                    img_names = os.listdir(os.path.join(root, sub_dir))
                    img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
    
                    # 遍历图片
                    for i in range(len(img_names)):
                        img_name = img_names[i]
                        path_img = os.path.join(root, sub_dir, img_name)
                        label = rmb_label[sub_dir]
                        data_info.append((path_img, int(label)))
    
            return data_info
    
    

    解读

    #1部分

    class RMBDataset(Dataset):
        def __init__(self, data_dir, transform=None):
            """
            rmb面额分类任务的Dataset
            :param data_dir: str, 数据集所在路径
            :param transform: torch.transform,数据预处理
            """
            self.label_name = {"1": 0, "100": 1}
            self.data_info = self.get_img_info(data_dir)  # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
            self.transform = transform
    

    初始化数据不多作赘述。

    #2部分

        def __getitem__(self, index):
            path_img, label = self.data_info[index]
            img = Image.open(path_img).convert('RGB')     # 0~255
    
            if self.transform is not None:
                img = self.transform(img)   # 在这里做transform,转为tensor等等
    
            return img, label
    
        def __len__(self):
            return len(self.data_info)
    
    • 为什么要在_get_ item 里面定义?因为pytorch中用dataloader类调用dataset类的时候是这样子的:

    • path_img, label = self.data_info[index] 接收数据的数据以及标签

    • img = Image.open(path_img).convert('RGB') # 0~255 将img转换成三通道模式

    •     if self.transform is not None:
              img = self.transform(img)   # 在这里做transform,转为tensor等等
      

    判断是否传入了transform,若传入了transform则进行transform.compounds里面的transform变换.

    • return img, label 返回数据及标签

    # 3 部分

    @staticmethod
    def get_img_info(data_dir):
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # 遍历类别
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
    
                # 遍历图片
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = rmb_label[sub_dir]
                    data_info.append((path_img, int(label)))
    
        return data_info
    
    • 此函数的作用,得到路径内所有图片的数据,并打上label
    • for root, dirs, _ in os.walk(data_dir): 此处涉及到os.walk函数,
    def walk(top: T,
     topdown: bool = True,
     onerror: Optional[(Exception) -> None] = None,
     followlinks: bool = False) -> Iterator[Tuple[T, List[T], List[T]]]
     top -- 是你所要遍历的目录的地址, 
     return--返回的是一个三元组(root,dirs,files)。
    
        root 所指的是当前正在遍历的这个文件夹的本身的地址
        dirs 是一个 list ,内容是该文件夹中所有的目录的名字(不包括子目录)
        files 同样是 list , 内容是该文件夹中所有的文件(不包括子目录)
                                       
    topdown --可选,为 True,则优先遍历 top 目录,否则优先遍历 top 的子目录(默认为开启)。如果 topdown 参数为 True,walk 会遍历top文件夹,与top 文件夹中每一个子目录。
    
    onerror -- 可选,需要一个 callable 对象,当 walk 需要异常时,会调用。
    
    followlinks -- 可选,如果为 True,则会遍历目录下的快捷方式(linux 下是软连接 symbolic link )实际所指的目录(默认关闭),如果为 False,则优先遍历 top 的子目录。
    
    • for sub_dir in dirs:
                  img_names = os.listdir(os.path.join(root, sub_dir))
                  img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
      

      先解释一下目录结构:

    image-20200416133247390

    1和100里面放着1和100元的图片.

    img_names = os.listdir(os.path.join(root, sub_dir)) 提取出.../1

    img_names = list(filter(lambda x: x.endswith('.jpg'), img_names)) 将.../1 下面所有的以.jpg结尾的文件名提取出来返回一个list,也就是说此时img_names成了一个list里面装满了.../1目录下所有的图片名字

    •         for i in range(len(img_names)):
                  img_name = img_names[i]
                  path_img = os.path.join(root, sub_dir, img_name)
                  label = rmb_label[sub_dir]
                  data_info.append((path_img, int(label)))
      

      这个函数主要作用是提取出img_names里面所有图片的路径以及label其中值得一提的是label = rmb_label[sub_dir] 由于本身文件夹的名字就是label所以提取label的方法就是提取文件夹的名字.

      最后返回一个data_info list 里面每个元素为元组形式(img_path,label).

  • 相关阅读:
    高性能MySQL笔记(第十一章 可扩展的MySQL)01
    高性能MySQL笔记(第十章 复制)02
    高性能MySQL笔记(第十章 复制)01
    高性能MySQL笔记(第六章 查询性能优化) 02
    高性能MySQL笔记(一个奇怪的问题)
    高性能MySQL笔记(第六章 查询性能优化) 01
    高性能MySQL笔记(第五章 创建高性能的索引) 02
    高性能MySQL笔记(第五章 创建高性能的索引) 01
    [Luogu] P1438 无聊的数列 | 线段树简单题
    [UCF HSPT 2021] Sharon’s Sausages | 思维 暴力
  • 原文地址:https://www.cnblogs.com/negu/p/12712337.html
Copyright © 2020-2023  润新知