torchvision包服务于PyTorch框架,包括了计算机视觉中一些流行的数据集、网络模型以及常见的图片变换方法,主要由以下几部分构成:
torchvision.datasets
: 一些加载数据的函数和常用的数据集接口
torchvision.models
:包含常用的模型结构(含预训练模型)
torchvision.transforms
:常见的图片变换,如裁剪、旋转等
torchvision.utils
:其它一些有用的方法
利用torchvision.datasets
接口可以得到许多种类的数据集,可以传给DataLoader
做进一步处理,使用这些数据集的API都差不多,以使用MNIST数据集为例,来看一下使用的流程
导入需要的包或模块:
import torch
import torchvision.transforms as transforms
import torchvision.datasets as dataset
MNIST的文档描述:
root
-就是存放数据的文件路径
train
-如果为True
,那么得到训练集,否则得到测试集
download
-True
表示需要从网上下载这个数据集,如果已经下载了,那么直接加载
transform
-对图片进行一些处理
示例:
mnist_train_data = datasets.MNIST('../MNIST', train=True, download=True, transform =
tranforms.Compose([
torch.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))
对图片的tranforms
有很多种,根据需要选择,他们可以利用torchvision.tranforms.Compose(transforms)
把这些操作串在一起,就像上面例子一样,形式是:
transforms.Compose([
transforms.CenterCrop(10),
transforms.ToTensor(),
])
更多的细节以后遇到再加以补充。。。
参考: