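明天再写一篇读取 tfrecords 的吧。昨天做的时候遇到了问题:电脑环境不行,老是提示一些库函数不存在,其实库已经导入进去了(这类报错很可能是 TensorFlow 版本不一致造成的,例如 TF 2.x 中已经没有 `tf.python_io` 和 `tf.train.batch`,需要改用 `tf.compat.v1` 下的对应接口),但 Python 就是这样,所以还没入坑的小伙伴去学 Caffe 吧,不要被 Python 毒害了。下面把代码贴上,其中有几个函数是没有用的,是参考之前一位大神的帖子写的,他写了好多函数来测试他的 records 有没有做成功,真是厉害,大神就是大神。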
import os
import random

import numpy as np  # not used below; kept because the original script imported it
import tensorflow as tf
from PIL import Image


def _tf_v1():
    """Return the namespace that holds the TF1-style APIs used in this script.

    TF 2.x removed ``tf.python_io`` and ``tf.train.batch``/``shuffle_batch``
    from the top-level ``tf`` namespace and keeps them only under
    ``tf.compat.v1``. TF 1.x releases that predate ``tf.compat.v1`` expose
    them directly on ``tf``. Resolving the namespace here lets the same code
    run on both instead of failing with "module has no attribute ...".
    """
    compat = getattr(tf, 'compat', None)
    return getattr(compat, 'v1', tf)


def _int64_feature(label):
    """Wrap a single integer in a ``tf.train.Feature`` (an int64 list of length 1)."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))


def _bytes_feature(imgdir):
    """Wrap a single ``bytes`` value in a ``tf.train.Feature``.

    Despite its name, ``imgdir`` is the byte payload to store (this script
    passes raw pixel bytes), not a path.
    """
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[imgdir]))


def float_list_feature(value):
    """Wrap an iterable of floats in a ``tf.train.Feature``. Unused by this script."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def get_example_nums(tf_records_filenames):
    """Return how many serialized records one TFRecord file contains.

    Args:
        tf_records_filenames: path to a single ``.tfrecords`` file.
    """
    iterator = _tf_v1().python_io.tf_record_iterator(tf_records_filenames)
    return sum(1 for _ in iterator)


def get_example_num(records_file_dir):
    """Same as :func:`get_example_nums`; kept so existing callers keep working."""
    return get_example_nums(records_file_dir)


def load_file(imagestxtdir, shuffle=False):
    """Parse an image-list text file into parallel path and label lists.

    Each non-blank line is expected to look like ``<image path> <int label>``,
    with fields separated by whitespace. Blank lines (e.g. a trailing newline
    at EOF) are skipped.

    Args:
        imagestxtdir: path of the list file.
        shuffle: if true, shuffle the lines (via ``random.shuffle``) first.

    Returns:
        ``(images, labels)``: ``images`` is a list of path strings taken from
        the first field; ``labels`` is a list of one-element ``[int]`` lists
        taken from the second field.
    """
    with open(imagestxtdir) as f:
        lines_list = f.readlines()
    if shuffle:
        random.shuffle(lines_list)

    images = []
    labels = []
    for line in lines_list:
        # str.split() with no argument trims both ends and collapses runs of
        # spaces/tabs, so "a.jpg  3" and "a.jpg\t3" parse like "a.jpg 3".
        fields = line.split()
        if not fields:
            continue
        images.append(fields[0])
        labels.append([int(fields[1])])
    return images, labels


def get_batch_images(images, labels, batch_size, labels_num, one_hot=False, shuffle=False, num_threads=1):
    """Group single-example tensors into (optionally shuffled) batches.

    This uses the TF1 queue-based batching ops, so it only works in graph
    mode (under TF 2.x call ``tf.compat.v1.disable_eager_execution()`` first),
    and the queue runners must be started before the batches are evaluated.
    It is not used by this script's ``__main__`` section.

    Args:
        images, labels: per-example tensors to batch together.
        batch_size: number of examples per batch.
        labels_num: depth of the one-hot encoding (only used if ``one_hot``).
        one_hot: convert the label batch with ``tf.one_hot(..., 1, 0)``.
        shuffle: use ``shuffle_batch`` instead of ``batch``.
        num_threads: number of enqueuing threads.

    Returns:
        ``(images_batch, labels_batch)`` tensors.
    """
    v1 = _tf_v1()
    min_after_dequeue = 200
    capacity = min_after_dequeue + 3 * batch_size
    if shuffle:
        images_batch, labels_batch = v1.train.shuffle_batch(
            [images, labels],
            batch_size=batch_size,
            capacity=capacity,
            min_after_dequeue=min_after_dequeue,
            num_threads=num_threads,
        )
    else:
        images_batch, labels_batch = v1.train.batch(
            [images, labels],
            batch_size=batch_size,
            num_threads=num_threads,
            capacity=capacity,
        )
    if one_hot:
        labels_batch = tf.one_hot(labels_batch, labels_num, 1, 0)
    return images_batch, labels_batch


def create_tf_records(image_base_dir, image_txt_dir, tfrecords_dir, resise_height, resize_weight, shuffle, log=5):
    """Convert the images listed in a text file into one TFRecord file.

    Each record is a ``tf.train.Example`` with two features:
      * ``'image_raw'``: the resized image's raw pixel bytes (``Image.tobytes``);
      * ``'label'``: the integer label from the list file.

    Args:
        image_base_dir: directory that the list file's image paths are joined to.
        image_txt_dir: list file, parsed by :func:`load_file`.
        tfrecords_dir: output ``.tfrecords`` path. Its parent directory is
            created if it does not exist.
        resise_height: target height in pixels (spelling kept for callers).
        resize_weight: target width in pixels (spelling kept for callers).
        shuffle: shuffle the list before writing.
        log: print a progress line every ``log`` entries and on the last one.
            Values <= 0 turn off the periodic lines.

    Images whose path does not exist are reported and skipped.
    """
    images_list, labels_list = load_file(image_txt_dir, shuffle)
    last_index = len(images_list) - 1

    out_dir = os.path.dirname(tfrecords_dir)
    if out_dir:
        # Make sure the output directory exists before opening the writer.
        os.makedirs(out_dir, exist_ok=True)

    writer = tf.io.TFRecordWriter(tfrecords_dir)
    try:
        for i, (image_name, single_label_list) in enumerate(zip(images_list, labels_list)):
            cur_image_dir = image_base_dir + '/' + image_name
            if not os.path.exists(cur_image_dir):
                print('the image path is not exists')
                continue
            # The context manager closes the underlying file handle.
            with Image.open(cur_image_dir) as image:
                # PIL's resize takes (width, height), not (height, width).
                # NOTE(review): the image mode is not normalized (no
                # convert('RGB')), so grayscale/RGBA inputs yield byte strings
                # of a different length — confirm the reader expects that.
                image_raw = image.resize((resize_weight, resise_height)).tobytes()
            single_label = single_label_list[0]
            if (log > 0 and i % log == 0) or i == last_index:
                print('------------processing:%d-th------------' % (i))
            example = tf.train.Example(features=tf.train.Features(feature={
                'image_raw': _bytes_feature(image_raw),
                'label': _int64_feature(single_label),
            }))
            writer.write(example.SerializeToString())
    finally:
        # Close (and flush) the writer even if an image fails mid-way.
        writer.close()


if __name__ == '__main__':
    resize_height = 224
    resize_width = 224
    shuffle = True
    log = 5
    dataset_dir = 'D:/软件/pycharmProject/wenyuPy/Dataset/VGG16'

    # Each split lives at <dataset_dir>/<split>, is listed in
    # <dataset_dir>/<split>.txt, and is written to
    # <dataset_dir>/record/<split>.tfrecords.
    for split in ('train', 'validation', 'test'):
        split_image_dir = dataset_dir + '/' + split
        split_txt_dir = dataset_dir + '/' + split + '.txt'
        split_records_dir = dataset_dir + '/record/' + split + '.tfrecords'
        create_tf_records(split_image_dir, split_txt_dir, split_records_dir,
                          resize_height, resize_width, shuffle, log)
        split_nums = get_example_nums(split_records_dir)
        print('the %s records number is:' % split, split_nums)
这是我自己电脑上的环境,只有五类图像。如果你想做更多类别的也可以,道理都是一样的,改一下路径就行。注释懒得写了,因为代码写得比较简单,哈哈哈。想转载的话随便转,但真正想学的人还是得自己敲一遍。我的博客写得很一般,估计应该没什么人看。