• PyTorch 之 Datasets


    实现一个定制的 Dataset 类

    Dataset 类是 PyTorch 图像数据集中最为重要的一个类,也是 PyTorch 中所有数据集加载类都应该继承的父类。其中,父类的两个特殊方法(双下划线方法)必须被重写(override)。

    • __getitem__(self, index) # 支持数据集索引的函数
    • __len__(self) # 返回数据集的大小

    Datasets 的框架:

    class CustomDataset(data.Dataset):
        """Template for a user-defined dataset; it must subclass ``data.Dataset``."""

        def __init__(self):
            """Set up the dataset.

            TODO: initialize the file path or the list of file names here.
            """

        def __getitem__(self, index):
            """Return the sample stored at position ``index``.

            TODO:
              1. Read the item for ``index`` from file
                 (e.g. with ``numpy.fromfile`` or ``PIL.Image.open``).
              2. Preprocess what was read (e.g. with ``torchvision.transforms``).
              3. Return the data pair (e.g. an image and its label).
            """
            return None

        def __len__(self):
            """Return the number of samples in the dataset.

            TODO: replace the placeholder 0 with the total size of your dataset.
            """
            size = 0
            return size
    

    举例:

    class MyDataset(Dataset):
        """Image-folder dataset that yields one image tensor per file.

        Args:
            root: Root directory in which the image files are stored.
            augment: Optional callable applied to each loaded image for data
                augmentation; augmentation is skipped when it is falsy.
        """

        # File-name suffixes that are accepted as images.
        IMAGE_SUFFIXES = (".jpg", ".png", ".JPG")

        def __init__(self, root, augment=None):
            # Collect the path of every image located directly under ``root``.
            # The context manager releases the directory handle promptly, and
            # sorting makes the index -> file mapping reproducible, because
            # os.scandir yields entries in arbitrary order.
            with os.scandir(root) as entries:
                self.image_files = np.array(sorted(
                    entry.path for entry in entries
                    if entry.name.endswith(self.IMAGE_SUFFIXES)
                ))
            self.augment = augment

        def __getitem__(self, index):
            """Load image ``index``, optionally augment it, and return it as a tensor."""
            # open_image is the image-reading function; it can be implemented
            # with PIL, OpenCV or a similar library.
            image = open_image(self.image_files[index])
            if self.augment:
                # Apply data augmentation to the loaded image.
                image = self.augment(image)
            # Images handed to PyTorch must be tensors.
            return to_tensor(image)

        def __len__(self):
            """Return the number of image files found under ``root``."""
            return len(self.image_files)
    

    下面是官方 MNIST 的例子:

    class MNIST(data.Dataset):
        """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

        Args:
            root (string): Root directory of dataset where ``processed/training.pt``
                and  ``processed/test.pt`` exist.
            train (bool, optional): If True, creates dataset from ``training.pt``,
                otherwise from ``test.pt``.
            download (bool, optional): If true, downloads the dataset from the internet and
                puts it in root directory. If dataset is already downloaded, it is not
                downloaded again.
            transform (callable, optional): A function/transform that  takes in an PIL image
                and returns a transformed version. E.g, ``transforms.RandomCrop``
            target_transform (callable, optional): A function/transform that takes in the
                target and transforms it.
        """
        urls = [
            'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
            'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
            'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
            'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
        ]
        raw_folder = 'raw'
        processed_folder = 'processed'
        training_file = 'training.pt'
        test_file = 'test.pt'
        classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
                   '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
        class_to_idx = {_class: i for i, _class in enumerate(classes)}

        @property
        def targets(self):
            """Labels of the active split (``train_labels`` or ``test_labels``)."""
            if self.train:
                return self.train_labels
            else:
                return self.test_labels

        def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
            """Load the processed split from ``root``, downloading first if requested.

            Raises:
                RuntimeError: If the processed files are not present under ``root``.
            """
            self.root = os.path.expanduser(root)
            self.transform = transform
            self.target_transform = target_transform
            self.train = train  # training set or test set

            if download:
                self.download()

            if not self._check_exists():
                raise RuntimeError('Dataset not found.' +
                                   ' You can use download=True to download it')

            if self.train:
                self.train_data, self.train_labels = torch.load(
                    os.path.join(self.root, self.processed_folder, self.training_file))
            else:
                self.test_data, self.test_labels = torch.load(
                    os.path.join(self.root, self.processed_folder, self.test_file))

        def __getitem__(self, index):
            """
            Args:
                index (int): Index
            Returns:
                tuple: (image, target) where target is index of the target class.
            """
            if self.train:
                img, target = self.train_data[index], self.train_labels[index]
            else:
                img, target = self.test_data[index], self.test_labels[index]

            # doing this so that it is consistent with all other datasets
            # to return a PIL Image
            img = Image.fromarray(img.numpy(), mode='L')

            if self.transform is not None:
                img = self.transform(img)

            if self.target_transform is not None:
                target = self.target_transform(target)

            return img, target

        def __len__(self):
            """Return the number of samples in the active split."""
            if self.train:
                return len(self.train_data)
            else:
                return len(self.test_data)

        def _check_exists(self):
            """Return True when both processed files exist under ``root``."""
            processed_dir = os.path.join(self.root, self.processed_folder)
            # Parentheses carry the line continuation, so the expression no
            # longer depends on a trailing backslash.
            return (os.path.exists(os.path.join(processed_dir, self.training_file)) and
                    os.path.exists(os.path.join(processed_dir, self.test_file)))

        def download(self):
            """Download the MNIST data if it doesn't exist in processed_folder already."""
            from six.moves import urllib
            import gzip

            if self._check_exists():
                return

            # Create each directory in its own try block: with a single shared
            # try, an already-existing raw folder raised EEXIST and skipped the
            # creation of the processed folder.
            for folder in (self.raw_folder, self.processed_folder):
                try:
                    os.makedirs(os.path.join(self.root, folder))
                except OSError as e:
                    if e.errno != errno.EEXIST:
                        raise

            # download files
            for url in self.urls:
                print('Downloading ' + url)
                # Named ``response`` rather than ``data`` so the module the
                # class inherits from is not shadowed; closed explicitly so the
                # connection is released even if reading fails.
                response = urllib.request.urlopen(url)
                try:
                    payload = response.read()
                finally:
                    response.close()
                filename = url.rpartition('/')[2]
                file_path = os.path.join(self.root, self.raw_folder, filename)
                with open(file_path, 'wb') as f:
                    f.write(payload)
                # Decompress next to the archive, then drop the archive.
                with open(file_path.replace('.gz', ''), 'wb') as out_f, \
                        gzip.GzipFile(file_path) as zip_f:
                    out_f.write(zip_f.read())
                os.unlink(file_path)

            # process and save as torch files
            print('Processing...')

            training_set = (
                read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')),
                read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte'))
            )
            test_set = (
                read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')),
                read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte'))
            )
            with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f:
                torch.save(training_set, f)
            with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f:
                torch.save(test_set, f)

            print('Done!')

        def __repr__(self):
            """Return a multi-line summary: size, split, root and both transforms."""
            fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
            fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
            tmp = 'train' if self.train is True else 'test'
            fmt_str += '    Split: {}\n'.format(tmp)
            fmt_str += '    Root Location: {}\n'.format(self.root)
            tmp = '    Transforms (if any): '
            # Continuation lines of a multi-line transform repr are indented
            # to line up under the first one.
            fmt_str += '{0}{1}\n'.format(
                tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
            tmp = '    Target Transforms (if any): '
            fmt_str += '{0}{1}'.format(
                tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
            return fmt_str
    
    
  • 相关阅读:
    Mysql 创建联合主键
    Shell中的while循环
    shell 日期加减运算
    PHP日期格式转时间戳
    Uber 叫车时,弹出以下代码导致无法打车(An email confirmation has been sent to...),解决办法
    如何让Table显示滚动条
    mySQL中replace的用法
    打豪车应用:uber详细攻略(附100元优惠码)
    svn 命令行创建和删除 分支和tags
    php ob_start()、ob_end_flush和ob_end_clean()多级缓冲
  • 原文地址:https://www.cnblogs.com/xxxxxxxxx/p/11429051.html
Copyright © 2020-2023  润新知