• 使用tensorflow中的Dataset来读取制作好的tfrecords文件


    上一篇我写了如何给自己的图像集制作tfrecords文件,现在我们就来讲讲如何读取已经创建好的文件,我们使用的是Tensorflow中的Dataset来读取我们的tfrecords,网上很多帖子应该是很久之前的了,绝大多数的做法是,先将tfrecords序列化成一个队列,然后使用TFRecordReader这个函数进行解析,解析出来的每一行都是一个record,然后再将每一个record进行还原,但是这个函数你在使用的时候会报出异常,原因就是它已经被dataset中新的读取方式所替代,下个版本中可能就无法使用了,因此不建议大家使用这个函数,好了,下面就来看看是如何进行读取的吧。

     1 import tensorflow as tf
     2 import matplotlib.pyplot as plt
     3 
     4 #定义可以一次获得多张图像的函数
     5 def show_image(image_dir):
     6     plt.imshow(image_dir)
     7     plt.axis('on')
     8     plt.show()
     9 
    10 #单个record的解析函数
    11 def decode_example(example):#,resize_height,resize_width,labels_nums):
    12     features=tf.io.parse_single_example(example,features={
    13         'image_raw':tf.io.FixedLenFeature([],tf.string),
    14         'label':tf.io.FixedLenFeature([],tf.int64)
    15     })
    16     tf_image=tf.decode_raw(features['image_raw'],tf.uint8)#这个其实就是图像的像素模式,之前我们使用矩阵来表示图像
    17     tf_image=tf.reshape(tf_image,shape=[224,224,3])#对图像的尺寸进行调整,调整成三通道图像
    18     tf_image=tf.cast(tf_image,tf.float32)*(1./255)#对图像进行归一化以便保持和原图像有相同的精度
    19     tf_label=tf.cast(features['label'],tf.int32)
    20     tf_label=tf.one_hot(tf_label,5,on_value=1,off_value=0)#将label转化成用one_hot编码的格式
    21     return tf_image,tf_label
    22 
    23 def batch_test(tfrecords_file):
    24     dataset=tf.data.TFRecordDataset(tfrecords_file)
    25     dataset=dataset.map(decode_example)
    26     dataset=dataset.shuffle(100).batch(4)
    27     iterator=tf.compat.v1.data.make_one_shot_iterator(dataset)
    28     batch_images,batch_labels=iterator.get_next()
    29 
    30     init_op=tf.compat.v1.global_variables_initializer()
    31     with tf.compat.v1.Session() as sess:
    32         sess.run(init_op)
    33         coord=tf.train.Coordinator()
    34         threads=tf.train.start_queue_runners(coord=coord)
    35         for i in range(4):
    36             images,labels=sess.run([batch_images,batch_labels])
    37             show_image(images[1,:,:,:])
    38             print('shape:{},tpye:{},labels:{}'.format(images.shape, images.dtype, labels))
    39 
    40         coord.request_stop()
    41         coord.join(threads)
    42 
    43 if __name__=='__main__':
    44     tfrecords_file='D:/软件/pycharmProject/wenyuPy/Dataset/VGG16/record/train.tfrecords'
    45     resize_height=224
    46     resize_width=224
    47     batch_test(tfrecords_file)

    我为了测试,写了batch_test这个函数,因为我想试一试看我做的tfrecords能不能被解析成功,如果你不想测试只想训练,那你直接把images_batch,和labels_batch放到网络中进行训练就可以了,还有一点要注意的,tf.global_variables_initializer()已经被tf.compat.v1.global_variables_initializer()所取代了,我做的时候不知道所以报了一个warning提示,同时tf.Sesssion()已经被tf.compat.v1.Session() 所替代,iterator=dataset.make_one_shot_iterator()已经被tf.compat.v1.data.make_one_shot_iterator(dataset)  所代替,这些异常要注意,然后我只是将每个batch的第二张图片显示出来了,你也可以显示其他的,但是意义不大,反正只是测试一下解析成功与否,成功了我们就不需要纠结别的了。好啦,就是这样,接下来我会把这些东西放到网络中进行训练,再更新我的学习,就酱。

  • 相关阅读:
    网页字体大小控制
    表格文本框搜索匹配
    表格展开和关闭
    表格复选框控制行高亮
    jquery表单验证
    文本框变大变小效果--jQuery
    滚动条高度变化jQuery
    点击标题显示隐藏效果--jQuery
    jQuery练习2-1
    jQuery练习2
  • 原文地址:https://www.cnblogs.com/daremosiranaihana/p/11444705.html
Copyright © 2020-2023  润新知