CIFAR image classification datasets
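- This module downloads the datasets from https://www.cs.toronto.edu/~kriz/cifar.html and parses the training and test sets into paddle reader creators.
- The CIFAR-10 dataset consists of 60,000 32x32 color images in 10 classes, with 6,000 images per class. There are 50,000 training images and 10,000 test images.
- The CIFAR-100 dataset is similar to CIFAR-10, except that it has 100 classes.
- Each function returns a reader creator. In every sample produced by the reader, the image pixel values lie in [0, 1]; the labels lie in [0, 9] for CIFAR-10 and in [0, 99] for CIFAR-100.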
paddle.dataset.cifar: https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/data/dataset_cn/cifar_cn.html
paddle.batch: https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/io_cn/batch_cn.html#batch
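The sample format described above can be checked with a few lines of NumPy. This is a minimal sketch on synthetic data, not the module's own code: the raw bytes are random, the label is made up, and dividing by 255 is an assumption about how 8-bit pixels end up in [0, 1].

```python
import numpy as np

# Synthetic stand-in for one raw CIFAR image: 3 channels of 32x32 8-bit pixels.
rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=3 * 32 * 32, dtype=np.uint8)

# Assumed normalization: divide by 255 so that every value falls in [0, 1].
sample = (raw / 255.0).astype(np.float32)
label = 3  # hypothetical CIFAR-10 label in [0, 9]

print(sample.shape)  # (3072,)
print(sample.min() >= 0.0 and sample.max() <= 1.0)  # True
```

A flat vector of 3072 values is therefore exactly one 32x32 image with three color channels.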
import numpy as np
import paddle
import paddle.dataset.cifar as cifar
import paddle.fluid as fluid
from PIL import Image
import matplotlib.pyplot as plt
# Get the dataset reader function
train10=cifar.train10(cycle=False)
# cifar.test10(cycle=False)
# train100=cifar.train100(cycle=True)
# train10:<function paddle.dataset.cifar.reader_creator.<locals>.reader()>
a_sample=next(train10())
# data,label
# a_sample:array([0.69803923, 0.69803923, 0.69803923, ..., 0.3137255 , 0.3137255 ,0.3019608 ], dtype=float32), 0
# a_sample[0].shape:(3072,)
# Wrap the reader so that samples are loaded in batches
train10_reader = paddle.batch(train10, batch_size=64)
# The result is a reader function that yields one batch at a time
# train10_reader:<function paddle.batch.batch.<locals>.batch_reader()>
# Read one batch and inspect it
for batch_id, data in enumerate(train10_reader()):
    print(len(data))  # 64
    break
# Image pixel values lie in [0, 1]; labels lie in [0, 9]
# data[i]: a sample
# data[i][0]: image data
# data[i][1]: label
# data[1][0].shape,data[1][1]:(3072,), 6
# (data[0][0] >= 0).all() and (data[0][0] <= 1).all():True
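To make the batching step concrete, here is a minimal pure-Python sketch of wrapping a reader so that it yields lists of samples. `batch_reader_sketch` and `toy_reader` are hypothetical names introduced for illustration; this is not the implementation of `paddle.batch`, and the `drop_last` flag is only assumed to behave like the parameter of the same name.

```python
from itertools import islice

def batch_reader_sketch(reader, batch_size, drop_last=False):
    """Hypothetical helper: wrap a reader so that it yields lists of samples."""
    def batched():
        it = iter(reader())
        while True:
            chunk = list(islice(it, batch_size))
            if not chunk:
                return  # the underlying reader is exhausted
            if len(chunk) < batch_size and drop_last:
                return  # discard the incomplete final batch
            yield chunk
    return batched

def toy_reader():
    # Ten fake (image, label) samples.
    for i in range(10):
        yield [float(i)], i

sizes = [len(b) for b in batch_reader_sketch(toy_reader, batch_size=4)()]
print(sizes)  # [4, 4, 2]
```

Each batch is a plain list of `(image, label)` tuples, which is why an individual sample is addressed as `data[i][0]` and `data[i][1]`.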
# The images are RGB color images
# 32*32=1024,3072/1024=3
# 32.0=np.sqrt(3072/3)
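def load_image(im):
    # Each sample is stored channel-first (1024 R values, then G, then B),
    # so reshape to (3, 32, 32) and move the channel axis last for plt.imshow.
    im = im.reshape(3, 32, 32).transpose(1, 2, 0).astype(np.float32)
    return im
def show(img):
    plt.imshow(img)
    plt.show()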
i=54
label=data[i][1]
print(label) # 5
img=load_image(data[i][0])
print(img.shape) # (32, 32, 3)
'''
Note: class index to class name:
0:airplane
1:automobile
2:bird
3:cat
4:deer
5:dog
6:frog
7:horse
8:ship
9:truck
'''
show(img)
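The CIFAR dataset page states that each image is stored as 1024 red values, then 1024 green, then 1024 blue, each plane in row-major order. The following sketch, on a synthetic all-red image rather than real CIFAR data, illustrates why a flat sample has to be reshaped channel-first and then transposed before it is passed to `plt.imshow`.

```python
import numpy as np

# Synthetic flat sample in CIFAR order: 1024 red, 1024 green, 1024 blue values.
flat = np.concatenate([
    np.ones(1024, dtype=np.float32),   # red channel fully on
    np.zeros(1024, dtype=np.float32),  # green channel off
    np.zeros(1024, dtype=np.float32),  # blue channel off
])

# Channel-first reshape, then move the channel axis last (height, width, channel).
hwc = flat.reshape(3, 32, 32).transpose(1, 2, 0)

# Reshaping straight to (32, 32, 3) spreads one color plane across R, G and B.
naive = flat.reshape(32, 32, 3)

print(hwc.shape)                          # (32, 32, 3)
print(bool((hwc[..., 0] == 1).all()))     # True: every pixel is pure red
print(bool((naive[..., 0] == 1).all()))   # False: the red plane is scrambled
```

Both results have shape (32, 32, 3), so the shape alone does not reveal the mistake; only the displayed image does.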