0.模块
torchvision有4个功能模块:model、datasets、transforms和utils。利用datasets可以下载一些经典数据集,本次笔记主要记录如何使用datasets的ImageFolder处理自定义数据集,以及如何使用transforms对源数据进行预处理、增强等。
1.transforms
transforms提供了对PIL Image对象和Tensor对象的常用操作。
1)对PIL Image的常见操作如下。
Scale/Resize:调整尺寸,长宽比保持不变。
CenterCrop、RandomCrop、RandomSizedCrop:裁剪图片,CenterCrop和RandomCrop在crop时是固定size,RandomResizedCrop则是random size的crop。
Pad:填充。
ToTensor:把一个取值范围是[0,255]的PIL.Image转换成Tensor。形状为(H,W,C)的Numpy.ndarray转换成形状为[C,H,W],取值范围是[0,1.0]的torch.FloatTensor。
RandomHorizontalFlip:图像随机水平翻转,翻转概率为0.5。
RandomVerticalFlip:图像随机垂直翻转。
ColorJitter:修改亮度、对比度和饱和度。
2)对Tensor的常见操作如下。
Normalize:标准化,即,减均值,除以标准差。
ToPILImage:将Tensor转为PIL Image。
如果要对数据集进行多个操作,可通过Compose将这些操作像管道一样拼接起来,类似于nn.Sequential。以下为示例代码:
这个东西会被送入你自定义的Dataset中!
transforms.Compose([
# 将给定的PIL.Image进行中心切割,得到给定的size
# size可以是tuple,(target_height, target_width)
# size也可以是一个Integer, 切出来一个正方形
transform.CenterCrop(10)
# 切割中心点的位置随机选取
transforms.RandomCrop(20, padding=0)
# 将一个取值范围是[0,255]的PIL.Image或者shape为(H,W,C)的numpy.ndarray
# 转换为形状为(C,H,W),取值范围是[0,1]的torch.FloatTensor
transforms.ToTensor()
# 规范化到[-1, -1]
transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5,0.5,0.5))
])
2.datasets.ImageFolder
当文件依据标签处于不同文件下时,如:
我们可以利用torchvision.datasets.ImageFolder来直接构造出dataset
loader = datasets.ImageFolder(path)
loader = data.DataLoader(dataset)
ImageFolder会将目录中的文件夹名自动转化成序列,当DataLoader载入时,标签自动就是整数序列了。
下面我们利用ImageFolder读取不同目录下的图片数据,然后使用transforms进行图像预处理,预处理有多个,我们用compose把这些操作拼接在一起。然后使用DataLoader加载。对处理后的数据用torchvision.utils中的save_image保存为一个png格式文件,然后用Image.open打开该png文件,详细代码如下:
from torchvision import transforms, utils
from torchvision import datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
my_trans = transforms.Compose([
transforms.RandomResizedCrop(224), #将给定图像随机裁剪为不同的大小和宽高比,然后缩放所裁剪得到的图像为制定的大小
transforms.RandomHorizontalFlip(), #图像水平翻转
transforms.ToTensor()
])
train_data = datasets.ImageFolder(r'./data/torchvision_data', transform = my_trans)
train_loader = DataLoader(train_data, batch_size=8, shuffle=True)
for i_batch, img in enumerate(train_loader):
if i_batch == 0:
print(img[1])
fig = plt.figure()
grid = utils.make_grid(img[0])
plt.imshow(grid.numpy().transpose((1, 2, 0)))
plt.show()
utils.save_image(grid,'test02.png')
break
这里我建立一个torchvision_data文件夹,把不同类型的图片放在不同的子文件夹下。
运行结果为:
[参考](https://blog.csdn.net/qq_39610915/category_10487496.html)