torchvision中的datasets模块种包含了多种常用的分类数据集相关的下载、导入函数,如表格:
数据集对应的类 | 描述 |
datasets.MNIST() | 手写字体数据集 |
datasets.FashionMNIST() | 衣服、鞋子、包等10类 |
datasets.KMNIST() | 一些文字的灰度数据 |
datasets.CocoCaptions() | 用于图像检测标注的MS COCO数据 |
datasets.CocoDetection() | 用于检测的MS COCO数据 |
datasets.LSUN() | 10个场景和20个目标的分类数据集 |
datasets.CIFAR10() | CIFAR10类数据集 |
datasets.CIFAR100() | CIFAR100类数据集 |
datasets.STL10() | 包含10类的分类数据集和大量的未标记数据 |
datasets.ImageFolder() | 定义一个数据加载器从文件种读取数据 |
torchvision.transforms模块
对应的类 | 描述 |
transforms.Compose() | 将多个transform组合起来使用 |
transforms.Scale() | 按照指定的图像尺寸对图像进行调整 |
transforms.CenterCrop() | 将图像进行中心切割,得到指定大小的图像 |
transforms.RandomCrop() | 切割中心点的位置随机选取 |
transforms.RandomHorizontalFlip() | 将图像进行随机水平翻转 |
transforms.RandomSizedCrop() | 将给定的图像随机切割,然后再变换给定大小 |
transforms.Pad() | 把图像所有的边用给定的pad value填充 |
transforms.ToTensor() | 把一个取值范围为[0,255]的PIL图像或形状为[H,W,C]的数组,转换成形状为[C,H,W],取值范围为[0,1.0],的张量(torch.FloatTensor) |
transforms.Normalize() | 将给定的图像进行规范化操作 |
transforms.Lambda(lambd) | 使用lambd作为转化器,可自定义图像操作方式 |
例如代码所示
def ImgSplit(img_root=img_root0,batch_size=BTACH_SIZE,trainrate=0.8): # 数据加载及处理,对数据进行翻转,亮度,对比度等数据增广 #print("图像预处理中。。。。。") transform = transforms.Compose([ transforms.Resize(224), #将图片按照比例缩放至224*224 transforms.RandomResizedCrop(224, scale=(0.6, 1.0), ratio=(0.8, 1.0)), transforms.RandomHorizontalFlip(), #随机旋转 torchvision.transforms.ColorJitter(brightness=0.5, contrast=0, saturation=0, hue=0), transforms.ToTensor(), #转为tensor transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) ]) all_data = torchvision.datasets.ImageFolder( root=img_root, transform=transform ) train_data, vaild_data = torch.utils.data.random_split(all_data, [int(trainrate * len(all_data)), len(all_data) - int(trainrate * len(all_data))]) train_set = torch.utils.data.DataLoader( train_data, batch_size=batch_size, shuffle=True ) test_set = torch.utils.data.DataLoader( vaild_data, batch_size=batch_size, shuffle=False ) #print("图像预处完成。。。。。")
这里将图像路径为img_root0的数据集划分为80%的训练集和20%的测试集,每次放入训练的数据是BATCH_SIZE