• 关于 torchvision.datasets.CIFAR10 的 train_data 属性


    在基于 PyTorch 0.4 的 DARTS 代码里，有如下两行代码：

    trn_data = datasets.CIFAR10(root=data_path, train=True, download=False, transform=train_transform)
    shape = trn_data.train_data.shape

    在 PyTorch 1.2 及以上版本（准确地说，是与之配套的 torchvision 版本，CIFAR10 类属于 torchvision，其版本号与 PyTorch 不同）中，查看源码可知，CIFAR10 类已经没有 train_data 属性，取而代之的是 data，因此要把第二行改为：

    shape = trn_data.data.shape

    datasets.CIFAR10 的源码（节选）如下：

    from __future__ import print_function
    from PIL import Image
    import os
    import os.path
    import numpy as np
    import sys
    
    if sys.version_info[0] == 2:
        import cPickle as pickle
    else:
        import pickle
    
    from .vision import VisionDataset
    from .utils import check_integrity, download_and_extract_archive
    
    
    class CIFAR10(VisionDataset):
        """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

        Args:
            root (string): Root directory of dataset where directory
                ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
            train (bool, optional): If True, creates dataset from training set, otherwise
                creates from test set.
            transform (callable, optional): A function/transform that takes in a PIL image
                and returns a transformed version. E.g., ``transforms.RandomCrop``
            target_transform (callable, optional): A function/transform that takes in the
                target and transforms it.
            download (bool, optional): If True, downloads the dataset from the internet and
                puts it in root directory. If dataset is already downloaded, it is not
                downloaded again.

        Attributes:
            data (numpy.ndarray): Every image of the selected split, stacked and
                laid out as ``(N, 32, 32, 3)`` (HWC). Older releases exposed the
                training images as ``train_data`` instead; use ``data`` now.
            targets (list): Labels read from each batch's ``labels`` (or
                ``fine_labels``) entry, in the same order as the rows of ``data``.

        Raises:
            RuntimeError: If the batch files are missing or fail the integrity check.
        """
        # Sub-directory of ``root`` that holds the pickled batch files.
        base_folder = 'cifar-10-batches-py'
        url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
        filename = "cifar-10-python.tar.gz"
        # Presumably the md5 of ``filename``; consumed by download() (not shown here).
        tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
        # [batch file name, expected md5] pairs making up each split.
        train_list = [
            ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
            ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
            ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
            ['data_batch_4', '634d18415352ddfa80567beed471001a'],
            ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
        ]

        test_list = [
            ['test_batch', '40351d587109b95175f43aff81a1287e'],
        ]
        # Class-name metadata file; presumably read by _load_meta() (not shown here).
        meta = {
            'filename': 'batches.meta',
            'key': 'label_names',
            'md5': '5ff9c542aee3614f3951f8cda6e48888',
        }

        def __init__(self, root, train=True, transform=None, target_transform=None,
                     download=False):

            super(CIFAR10, self).__init__(root, transform=transform,
                                          target_transform=target_transform)

            self.train = train  # True -> training split, False -> test split

            if download:
                self.download()

            # NOTE(review): _check_integrity() is not shown in this excerpt; presumably
            # it verifies the md5s listed above before any file is unpickled below.
            if not self._check_integrity():
                raise RuntimeError('Dataset not found or corrupted.' +
                                   ' You can use download=True to download it')

            downloaded_list = self.train_list if self.train else self.test_list

            self.data = []
            self.targets = []

            # Load the pickled numpy arrays, one batch file at a time. The md5 in
            # each entry is not needed here, so it is deliberately ignored.
            for file_name, _checksum in downloaded_list:
                file_path = os.path.join(self.root, self.base_folder, file_name)
                with open(file_path, 'rb') as f:
                    # SECURITY: pickle.load can execute arbitrary code; only load
                    # files from a trusted source whose checksums have been verified.
                    if sys.version_info[0] == 2:
                        entry = pickle.load(f)
                    else:
                        # latin1 lets Python 3 decode byte strings in pickles that
                        # were presumably written by Python 2.
                        entry = pickle.load(f, encoding='latin1')
                    self.data.append(entry['data'])
                    # Batches carry either 'labels' or 'fine_labels'.
                    if 'labels' in entry:
                        self.targets.extend(entry['labels'])
                    else:
                        self.targets.extend(entry['fine_labels'])

            # Stack all batches and view them as (N, 3, 32, 32) images, then move
            # the channel axis last so ``data`` is (N, 32, 32, 3).
            self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
            self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

            # download, _check_integrity and _load_meta are methods of this class
            # that are not shown in this excerpt.
            self._load_meta()
  • 相关阅读:
    UVA 10600 ACM Contest and Blackout(次小生成树)
    UVA 10369
    UVA Live 6437 Power Plant 最小生成树
    UVA 1151 Buy or Build MST(最小生成树)
    UVA 1395 Slim Span 最小生成树
    POJ 1679 The Unique MST 次小生成树
    POJ 1789 Truck History 最小生成树
    POJ 1258 Agri-Net 最小生成树
    ubuntu 用法
    ubuntu 搭建ftp服务器,可以通过浏览器访问,filezilla上传文件等功能
  • 原文地址:https://www.cnblogs.com/yqpy/p/11831717.html
Copyright © 2020-2023  润新知