一、项目结构
二、代码
1 data_loader = torch.utils.data.DataLoader(
2 torchvision.datasets.ImageFolder('traing_dataset',
3 transform=torchvision.transforms.Compose([
4 torchvision.transforms.Resize([28, 28]), # 裁剪图片
5 torchvision.transforms.Grayscale(1), # 单通道
6 torchvision.transforms.ToTensor(), # 将图片数据转成tensor格式
7 torchvision.transforms.Normalize( # 归一化
8 (0.1307,), (0.3081,))
9 ])),
10 batch_size=10, shuffle=False) # 10张图片
三、显示效果
1 def plot_image(img, label, name):
2 fig = plt.figure()
3 for i in range(6): # 只显示6张
4 plt.subplot(2, 3, i+1) # 2行3列第i+1张
5 plt.tight_layout()
6 plt.imshow(img[i][0]*0.3081+0.1307, cmap='Greys', interpolation='none')
7 plt.title("{}:{}".format(name, label[i].item())) # 标题名称
8 plt.xticks([])
9 plt.yticks([])
10 plt.show()
11
12 x, y = next(iter(data_loader)) # 文件夹的名称即为图片的label
13 print(x.shape, y.shape)
14 plot_image(x, y, 'image')