• TensorFlow------TFRecords的读取实例


    TensorFlow------TFRecords的读取实例:

    import os
    import tensorflow as tf
    
    # 定义cifar的数据等命令行参数
    FLAGS = tf.app.flags.FLAGS
    tf.app.flags.DEFINE_string('cifar_dir', './data/cifar10/cifar-10-batches-bin', '文件的目录')
    tf.app.flags.DEFINE_string('cifar_tfrecords', './tmp/cifar.tfrecords', '存储tfrecords的文件')
    
    
    class CifarRead(object):
        '''
        完成读取二进制文件,写进tfrecords,读取tfrecords
        :param object:
        :return:
        '''
    
        def __init__(self, filelist):
            # 文件列表
            self.file_list = filelist
    
            # 定义读取的图片的一些属性
            self.height = 32
            self.width = 32
            self.channel = 3
            # 二进制文件每张图片的字节
            self.label_bytes = 1
            self.image_bytes = self.height * self.width * self.channel
            self.bytes = self.label_bytes + self.image_bytes
    
        def read_and_decode(self):
            # 1. 构建文件队列
            file_queue = tf.train.string_input_producer(self.file_list)
    
            # 2. 构建二进制文件读取器,读取内容,每个样本的字节数
            reader = tf.FixedLengthRecordReader(self.bytes)
    
            key, value = reader.read(file_queue)
    
            # 3. 解码内容,二进制文件内容的解码 label_image包含目标值和特征值
            label_image = tf.decode_raw(value, tf.uint8)
            print(label_image)
    
            # 4.分割出图片和标签数据,特征值和目标值
            label = tf.slice(label_image, [0], [self.label_bytes])
    
            image = tf.slice(label_image, [self.label_bytes], [self.image_bytes])
            print('---->')
            print(image)
    
            # 5. 可以对图片的特征数据进行形状的改变 [3072]-->[32,32,3]
            image_reshape = tf.reshape(image, [self.height, self.width, self.channel])
    
            print('======>')
            print(label)
            print('======>')
    
            # 6. 批处理数据
            image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
    
            print(image_batch, label_batch)
    
            return image_batch, label_batch
        # 读取并存储tfrecords文件
        # def write_ro_tfrecords(self, image_batch, label_batch):
        #     '''
        #     将图片的特征值和目标值存进tfrecords
        #     :param image_batch: 10张图片的特征值
        #     :param label_batch: 10张图片的目标值
        #     :return: None
        #     '''
        #     # 1.建立TFRecord存储器
        #     writer = tf.python_io.TFRecordWriter(FLAGS.cifar_tfrecords)
        #
        #     # 2. 循环将所有样本写入文件,每张图片样本都要构造example协议
        #     for i in range(10):
        #         # 取出第i个图片数据的特征值和目标值
        #         image = image_batch[i].eval().tostring()
        #
        #         label = int(label_batch[i].eval()[0])
        #
        #         # 构造一个样本的example
        #         example = tf.train.Example(features=tf.train.Features(feature={
        #             'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
        #             'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
        #         }))
        #
        #         # 写入单独的样本
        #         writer.write(example.SerializeToString())
        #
        #     # 关闭
        #     writer.close()
        #     return None
    
        def read_from_tfrecords(self):
            # 1. 构造文件队列
            file_queue = tf.train.string_input_producer([FLAGS.cifar_tfrecords])
    
            # 2. 构造文件阅读器,读取内容example,value一个样本的序列化example
            reader = tf.TFRecordReader()
    
            key, value = reader.read(file_queue)
    
            # 3. 解析example
            features = tf.parse_single_example(value, features={
                'image': tf.FixedLenFeature([], tf.string),
                'label': tf.FixedLenFeature([], tf.int64),
            })
    
            print(features['image'], features['label'])
    
            # 4. 解码内容,如果读取的内容格式是string需要解码,如果是int64,float32不需要解码
            image = tf.decode_raw(features['image'], tf.uint8)
    
            # 固定图片的形状,方便与批处理
            image_reshape = tf.reshape(image, [self.height, self.width, self.channel])
    
            label = tf.cast(features['label'], tf.int32)
    
            print(image_reshape, label)
    
            # 进行批处理
            image_batch,label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
    
            return image_batch,label_batch
    
    
    if __name__ == '__main__':
        # 找到文件,构建列表  路径+名字  ->列表当中
        file_name = os.listdir(FLAGS.cifar_dir)
    
        # 拼接路径 重新组成列表
        filelist = [os.path.join(FLAGS.cifar_dir, file) for file in file_name if file[-3:] == 'bin']
    
        # 调用函数传参
        cf = CifarRead(filelist)
        # image_batch,label_batch = cf.read_and_decode()
    
        image_batch, label_batch = cf.read_from_tfrecords()
    
        # 开启会话
        with tf.Session() as sess:
            # 定义一个线程协调器
            coord = tf.train.Coordinator()
    
            # 开启读文件的线程
            threads = tf.train.start_queue_runners(sess, coord=coord)
    
            # 存进tfrecords文件
            # print('开始存储')
            # cf.write_ro_tfrecords(image_batch,label_batch)
            # print('结束存储')
            # 打印读取的内容
            print(sess.run([image_batch,label_batch]))
    
            # 回收子线程
            coord.request_stop()
    
            coord.join(threads)
  • 相关阅读:
    寒假每日总结——2020.2.1
    亿级用户下的新浪微博平台架构读后感
    京东话费充值系统架构演讲读后感
    京东物流系统架构演讲中的最佳实践读后感
    京东上千页面搭建基石——CMS前后端分离演讲史读后感
    数据蜂巢架构演讲之路读后感
    关于SOA架构设计的案例分析下
    京东虚拟业务多维订单系统架构设计读后感
    在VUE-CLI 3下的第一个Element-ui项目(菜鸟专用)
    在vue-cli3中优雅的使用 icon
  • 原文地址:https://www.cnblogs.com/fwl8888/p/9762647.html
Copyright © 2020-2023  润新知