import tensorflow as tf import os """ uint8:存储,节约空间,比如在图片处理时,图片解码之前 float32:矩阵计算,提高精度,比如在图片处理时,图片解码之后 """ # 训练数据连接:http://www.cs.toronto.edu/~kriz/cifar.html # 定义cifar的数据命令行参数 FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_string("cifar_dir", r"C:UsersAdministratorPycharmProjectslearntest ensodatacifar10", "文件的目录") class CifarRead(object): """读取二进制文件,写入tfrecords,读取tfrecords""" def __init__(self, filelist): # 文件列表 self.filelist = filelist # 定义读取图片的一些属性,cifar下载的文件默认是32*32像素,彩色通道3,目标值1比特 self.height = 32 self.weight = 32 self.channel = 3 self.label_bytes = 1 # 二进制文件每张图片的字节 self.bytes = self.height * self.weight * self.channel + self.label_bytes def read_and_decode(self): # 1.构造文件队列 file_queue = tf.train.string_input_producer(self.filelist) # 2.构造二进制文件读取器 reader = tf.FixedLengthRecordReader(self.bytes) key, value = reader.read(file_queue) # 3.二进制文件内容解码 label_image = tf.decode_raw(value, tf.uint8) # 4.将label_image中的特征值和目标值分割开来,cast目标值是0-9的整数所以转换成int32类型,特征值将用于计算,转换成float32类型 label = tf.cast(tf.slice(label_image, [0], [self.label_bytes]), tf.int32) image = tf.cast(tf.slice(label_image, [self.label_bytes], [self.bytes - self.label_bytes]), tf.float32) # print(label, image) # 返回结果Tensor("Slice:0", shape=(1,), dtype=uint8) Tensor("Slice_1:0", shape=(3072,), dtype=uint8) # 5.可以对图片特征数据进行形状改变[3072] ==> [32, 32, 3] image_reshape = tf.reshape(image, [self.height, self.weight, self.channel]) # 6.进行批处理 image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10) return image_batch, label_batch if __name__ == "__main__": # 构造文件列表 file_name = os.listdir(FLAGS.cifar_dir) filelist = [os.path.join(FLAGS.cifar_dir, file) for file in file_name if file[-3:] == "bin"] cf = CifarRead(filelist) image_batch, label_batch = cf.read_and_decode() # 开启会话 with tf.Session() as sess: # 定义线程协调器 coord = tf.train.Coordinator() # 开启读取文件的线程 thd = tf.train.start_queue_runners(sess, coord=coord, start=True) # 打印读取内容 print(sess.run([image_batch, label_batch])) # 回收子线程 coord.request_stop() coord.join(thd)