• torchvision


      torchvision中的datasets模块种包含了多种常用的分类数据集相关的下载、导入函数,如表格:

    数据集对应的类 描述
    datasets.MNIST() 手写字体数据集
    datasets.FashionMNIST() 衣服、鞋子、包等10类
    datasets.KMNIST() 一些文字的灰度数据
    datasets.CocoCaptions() 用于图像检测标注的MS COCO数据
    datasets.CocoDetection() 用于检测的MS COCO数据
    datasets.LSUN() 10个场景和20个目标的分类数据集
    datasets.CIFAR10() CIFAR10类数据集
    datasets.CIFAR100() CIFAR100类数据集
    datasets.STL10() 包含10类的分类数据集和大量的未标记数据
    datasets.ImageFolder() 定义一个数据加载器从文件种读取数据

    torchvision.transforms模块

    对应的类 描述
    transforms.Compose() 将多个transform组合起来使用
    transforms.Scale() 按照指定的图像尺寸对图像进行调整
    transforms.CenterCrop() 将图像进行中心切割,得到指定大小的图像
    transforms.RandomCrop() 切割中心点的位置随机选取
    transforms.RandomHorizontalFlip() 将图像进行随机水平翻转
    transforms.RandomSizedCrop() 将给定的图像随机切割,然后再变换给定大小
    transforms.Pad() 把图像所有的边用给定的pad value填充
    transforms.ToTensor() 把一个取值范围为[0,255]的PIL图像或形状为[H,W,C]的数组,转换成形状为[C,H,W],取值范围为[0,1.0],的张量(torch.FloatTensor)
    transforms.Normalize() 将给定的图像进行规范化操作
    transforms.Lambda(lambd) 使用lambd作为转化器,可自定义图像操作方式

    例如代码所示

    def ImgSplit(img_root=img_root0,batch_size=BTACH_SIZE,trainrate=0.8):
        # 数据加载及处理,对数据进行翻转,亮度,对比度等数据增广
        #print("图像预处理中。。。。。")
        transform = transforms.Compose([
            transforms.Resize(224),             #将图片按照比例缩放至224*224
            transforms.RandomResizedCrop(224, scale=(0.6, 1.0), ratio=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(),      #随机旋转
            torchvision.transforms.ColorJitter(brightness=0.5, contrast=0, saturation=0, hue=0),
            transforms.ToTensor(),              #转为tensor
            transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
        ])
        all_data = torchvision.datasets.ImageFolder(
            root=img_root,
            transform=transform
        )
        
        train_data, vaild_data = torch.utils.data.random_split(all_data, [int(trainrate * len(all_data)),
                                                                          len(all_data) - int(trainrate * len(all_data))])
    
        train_set = torch.utils.data.DataLoader(
            train_data,
            batch_size=batch_size,
            shuffle=True
        )
        test_set = torch.utils.data.DataLoader(
            vaild_data,
            batch_size=batch_size,
            shuffle=False
        )
        #print("图像预处完成。。。。。")
    View Code

    这里将图像路径为img_root0的数据集划分为80%的训练集和20%的测试集,每次放入训练的数据是BATCH_SIZE

  • 相关阅读:
    微软老将Philip Su的离职信:回首12年职场生涯的心得和随笔[转]
    <Programming Ruby 1.9 The Pragmatic Programmer>读书笔记
    Ruby:Update value on specific row but keep the headers
    解决ImportError: cannot import name webdriver
    Ruby几个相关目录
    《发财日记》处处都是名言警句
    软件研发管理最佳实践(20121020 深圳)
    中国过程改进年会会前培训:让敏捷落地! 软件研发管理最佳实践(2012530 北京)
    展示你的才华,成就你的事业!—— 疯狂讲师
    网络直播课程 体验极限编程(XP)
  • 原文地址:https://www.cnblogs.com/2020zxc/p/14629702.html
Copyright © 2020-2023  润新知