• 将本地图片数据制作成内存对象数据集|tensorflow|手写数字制作成内存对象数据集|tf队列|线程


     

    样本说明:

      tensorflow经典实例之手写数字识别。MNIST数据集。

    数据集dir名称

    每个文件夹代表一个标签label,每个label中有820个手写数字的图片

    标签label为0的文件夹中部分bmp图片示例

    import tensorflow as tf
    import os
    from matplotlib import pyplot as plt
    import numpy as np
    from sklearn.utils import shuffle
    相关模块
    def load_sample(sample_dir):
        '''
        9个标签(label)对应9*820个图片,这一步将每个图片的相对路径及图片名称制作成列表    每个推按对应的标签label制作成列表
        :param sample_dir: mnist_digits_images
        :return: 文件名称列表  标签列表
        '''
        print ('正在导入样本数据..')
        lfilenames = []   ###用于接收每个图片名称的空列表
        labelsnames = []  ###用于接收每个图片对应的标签label的空列表
        for (dirpath, dirnames, filenames) in os.walk(sample_dir):#递归遍历文件夹
            '''os.walk的运行规则,自行补充'''
            for filename in filenames:                            #遍历所有文件名
                filename_path = os.sep.join([dirpath, filename])
                lfilenames.append(filename_path)               #添加文件名
                labelsnames.append( dirpath.split('\')[-1] )#添加文件名对应的标签
        ###此时得到的标签列表是字符串类型的,下面将字符串列表转换成数字列表
        lab= list(sorted(set(labelsnames)))  #标签列表去重,set有去重功能,sorted对set排序
        labdict=dict( zip( lab  ,list(range(len(lab)))  )) #生成字典:字符串{'0':0,'1':1,'2':2,......}
        labels = [labdict[i] for i in labelsnames]         ##通过列表解析,将字符串标签转换成数字标签
        ##将列表转换成数组,并且使用shuffle乱序数组,统一个shullfe下,各个列表乱序的规则一致,所以图片标签一一对应的关系不变
        return shuffle(np.asarray( lfilenames),np.asarray( labels)),np.asarray(lab)  
    文件名称数组(列表) 和 标签数组(列表)
    def get_batches(image,label,input_w,input_h,channels,batch_size):
        ##在会话启动队列之后,进行入队和出队操作。出队一个数据,为queue=[image,label],image:图片名称(相对路径) label :标签
        queue = tf.train.slice_input_producer([image,label])  #使用tf.train.slice_input_producer实现一个输入的队列(准备,此时并没有数据入队)
    
        label = queue[1]                                        #从输入队列里读取标签
    
        image_c = tf.read_file(queue[0])                        #从输入队列里读取image路径【读取图片】
    
        image = tf.image.decode_bmp(image_c,channels)           #将读取的图片解码
    
        image = tf.image.resize_image_with_crop_or_pad(image,input_w,input_h) #修改图片大小
    
    
        image = tf.image.per_image_standardization(image) #图像标准化处理,(x - mean) / adjusted_stddev
    
        image_batch,label_batch = tf.train.batch([image,label],#调用tf.train.batch函数生成批次数据
                   batch_size = batch_size,
                   num_threads = 64)
    
        images_batch = tf.cast(image_batch,tf.float32)   #将数据类型转换为float32
    
        labels_batch = tf.reshape(label_batch,[batch_size])#修改标签的形状shape
        return images_batch,labels_batch
    通过队列读取图片并对图片进行处理---生成批次数据
    data_dir = 'mnist_digits_images\'  #定义文件路径
    
    (image,label),labelsnames = load_sample(data_dir)   #载入文件名称与标签
    batch_size = 16        ##定义批次大小
    image_batches,label_batches = get_batches(image,label,28,28,1,batch_size)  ##读取图片,预处理,生成批次
    with tf.Session() as sess:
        '''会话初始化'''
        init = tf.global_variables_initializer()
        sess.run(init)
        '''创建队列协调器,在get_batches函数中定义了tf.train.slice_input_producer输入队列,创建协调器,并启动队列就可以用队列调用数据了'''
        coord = tf.train.Coordinator()
        ###启动队列线程
        threads = tf.train.start_queue_runners(sess=sess,coord=coord)
        try:
            for step in np.arange(5): ##定义迭代次数:每迭代一次,一个批次的数据注入
                ###查看队列是否开启(队列关闭返回True,开启返回False),如果关闭,终止程序
    
                if coord.should_stop():
    
                    break
                ##获取批次数据输入(注入数据)
                images,labels = sess.run([image_batches,label_batches])
                print('一个批次图片(数量)',len(images))
                print('一个批次标签', labels)
        except tf.errors.OutOfRangeError:
            print('完成,现在开始终止所有线程')
        finally:
            coord.request_stop()
            print('所有线程请求终止')
        coord.join(threads)
        print('所有线程终止')
    
    '''
    输出结果:
    一个批次图片(数量) 16
    一个批次标签 [1 1 6 5 8 1 7 5 1 8 5 6 4 5 9 7]
    一个批次图片(数量) 16
    一个批次标签 [9 7 2 3 6 7 1 1 8 0 3 4 7 7 7 9]
    一个批次图片(数量) 16
    一个批次标签 [6 2 7 9 0 5 3 9 5 0 4 1 3 3 2 6]
    一个批次图片(数量) 16
    一个批次标签 [3 1 9 7 9 0 9 1 7 3 9 8 9 4 1 9]
    一个批次图片(数量) 16
    一个批次标签 [0 2 3 3 9 8 7 9 9 9 8 6 1 6 9 1]
    所有线程请求终止
    所有线程终止
    '''
  • 相关阅读:
    laravel 验证码手机与提交手机的验证?
    微信公众平台开发——微信授权登录(OAuth2.0)
    个人网站可以申请微信授权登录吗
    个人网站可以申请微信授权登录吗?
    个体户微信公众号认证怎么做?无公章
    [微信开发] 没有组织机构代码证、公章怎么认证微信公众号?
    mysql中int、bigint、smallint 和 tinyint的区别详细介绍
    laravel5.6 QQ 第三方登录
    如何给网站的链接设置为绝对地址原文链接
    ArcGIS中文件共享锁定数据溢出 这个方法不行,建议用gdb,不要用mdb
  • 原文地址:https://www.cnblogs.com/liuhuacai/p/11726487.html
Copyright © 2020-2023  润新知