1.从checkpoint中获取全部的变量名和变量值
tf.contrib.framework.list_variables(model_dir)
tf.contrib.framework.load_variable(model_dir, var_name)
2.清除 tf.Session
tf.reset_default_graph() 重置计算图
3. 使用 TFRecord
numpy数据可以直接制作dataset ds = tf.data.Dataset.from_tensor_slices(trg)
正常情况下的写法如下(先写入,再解析读取):
# Demo: serialize numpy arrays as raw bytes into a TFRecord file.
a = np.random.randint(0, 10, (10))      # int64 ids (default randint dtype on 64-bit)
b = np.random.rand(10, 20)              # float64 soft targets
a1 = a.tobytes()
b1 = b.tobytes()
# Build a single Example holding both byte strings.
example = tf.train.Example(features=tf.train.Features(feature={
    "soft_targets": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b1])),
    'src_wids': tf.train.Feature(bytes_list=tf.train.BytesList(value=[a1])),
}))
# Context manager guarantees the writer is flushed and closed even on error
# (the original leaked the writer if an exception occurred before close()).
with tf.python_io.TFRecordWriter("./tfr/train.tfrecords") as writer:
    # NOTE(review): the SAME example is written 10 times — demo only;
    # real code would build one Example per record.
    for _ in range(10):
        writer.write(example.SerializeToString())
def _parse_function(example_proto):
    """Parse one serialized Example back into (feature, label) tensors.

    Args:
        example_proto: scalar string tensor, one serialized tf.train.Example.
    Returns:
        (feature, label): 1-D tensors decoded from the raw byte strings.
    """
    features = tf.parse_single_example(
        example_proto,
        features={
            'src_wids': tf.FixedLenFeature([], tf.string),
            'soft_targets': tf.FixedLenFeature([], tf.string)
        }
    )
    # Pull out the fields we need (label = soft targets, feature = word ids).
    # The decode dtype MUST match the dtype that was serialized:
    # np.random.rand produces float64 and default randint produces int64,
    # so decoding soft_targets as float32 (as the original did) yields
    # twice as many garbage values.
    label = tf.decode_raw(features['soft_targets'], tf.float64)
    feature = tf.decode_raw(features['src_wids'], tf.int64)
    return feature, label


# tf.data.TFRecordDataset replaces the deprecated tf.contrib.data
# version (TensorFlow >= 1.4).
dataset = tf.data.TFRecordDataset("./tfr/train.tfrecords")
dataset = dataset.map(_parse_function)
dataset = dataset.batch(2)
iterator = dataset.make_initializable_iterator()
with tf.Session() as sess:
    sess.run(iterator.initializer)
    ids, trg = sess.run(iterator.get_next())
# Author: Anthony
# Convert a directory tree of labelled images into sharded TFRecord files.
import tensorflow as tf
import os
import random
import math
import sys

# Number of images held out as the test split.
_NUM_TEST = 40
# Random seed so the train/test split is reproducible.
_RANDOM_SEED = 0
# Number of shards per split.
_NUM_SHARDS = 2
# Dataset root: one sub-directory per class.
DATASET_DIR = '/home/anthony/文档/数据集_带标签/SHIYAN_SAMEZIZE'
# Output label-map file.
LABELS_FILENAME = '/home/anthony/文档/数据集_带标签/SHIYAN_SAMEZIZE_labels.txt'


def _get_dataset_filename(dataset_dir, split_name, shard_id):
    """Return the path of the TFRecord shard for (split_name, shard_id)."""
    output_filename = 'image_%s_%05d-of-%05d.tfrecord' % (
        split_name, shard_id, _NUM_SHARDS)
    return os.path.join(dataset_dir, output_filename)


def _dataset_exists(dataset_dir):
    """Return True iff every expected TFRecord shard already exists."""
    for split_name in ['train', 'test']:
        for shard_id in range(_NUM_SHARDS):
            output_filename = _get_dataset_filename(
                dataset_dir, split_name, shard_id)
            if not tf.gfile.Exists(output_filename):
                return False
    return True


def _get_filenames_and_classes(dataset_dir):
    """Scan dataset_dir and return (photo_filenames, class_names).

    Each immediate sub-directory of dataset_dir is a class; every file
    inside it is one sample of that class.
    """
    directories = []
    class_names = []
    for filename in os.listdir(dataset_dir):
        path = os.path.join(dataset_dir, filename)
        if os.path.isdir(path):
            directories.append(path)
            class_names.append(filename)
    photo_filenames = []
    for directory in directories:
        for filename in os.listdir(directory):
            photo_filenames.append(os.path.join(directory, filename))
    return photo_filenames, class_names


def int64_feature(values):
    """Wrap an int (or list of ints) as a tf.train.Feature."""
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def bytes_feature(values):
    """Wrap a bytes value as a tf.train.Feature."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))


def image_to_tfexample(image_data, image_format, class_id):
    """Build a tf.train.Example from encoded image bytes + label."""
    return tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': bytes_feature(image_data),
        'image/format': bytes_feature(image_format),
        'image/class/label': int64_feature(class_id),
    }))


def write_label_file(labels_to_class_names, dataset_dir, filename=LABELS_FILENAME):
    """Write the {id: class_name} map to a text file, one entry per line.

    NOTE(review): LABELS_FILENAME is absolute, so os.path.join ignores
    dataset_dir here — presumably intentional; confirm.
    """
    label_filename = os.path.join(dataset_dir, filename)
    with tf.gfile.Open(label_filename, 'w') as f:
        for label in labels_to_class_names:
            class_name = labels_to_class_names[label]
            # One "id:name" mapping per line (the pasted original had lost
            # the newline, mashing every label onto a single line).
            f.write('%d:%s\n' % (label, class_name))


def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
    """Write `filenames` into _NUM_SHARDS TFRecord shards for one split."""
    assert split_name in ['train', 'test']
    # Size of each shard (the last shard is clamped with min() below).
    num_per_shard = int(len(filenames) / _NUM_SHARDS)
    with tf.Graph().as_default():
        with tf.Session():
            for shard_id in range(_NUM_SHARDS):
                output_filename = _get_dataset_filename(
                    dataset_dir, split_name, shard_id)
                with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
                    start_ndx = shard_id * num_per_shard
                    end_ndx = min((shard_id + 1) * num_per_shard, len(filenames))
                    for i in range(start_ndx, end_ndx):
                        try:
                            sys.stdout.write(
                                '\r>> Converting image %d/%d shard %d'
                                % (i + 1, len(filenames), shard_id))
                            sys.stdout.flush()
                            # Context manager closes the file handle; the
                            # original FastGFile(...).read() leaked it.
                            with tf.gfile.FastGFile(filenames[i], 'rb') as f:
                                image_data = f.read()
                            # basename of the parent directory is the class name.
                            class_name = os.path.basename(
                                os.path.dirname(filenames[i]))
                            class_id = class_names_to_ids[class_name]
                            example = image_to_tfexample(
                                image_data, b'jpg', class_id)
                            tfrecord_writer.write(example.SerializeToString())
                        except IOError as e:
                            # BUG FIX: original printed filenames[1] (always
                            # the same file) instead of the failing filenames[i].
                            print('could not read:', filenames[i])
                            print('error:', e)
                            print('skip it')
    sys.stdout.write('\n')
    sys.stdout.flush()


if __name__ == '__main__':
    # Skip the (slow) conversion if the shards already exist.
    if _dataset_exists(DATASET_DIR):
        print('tfrecord exists')
    else:
        photo_filenames, class_names = _get_filenames_and_classes(DATASET_DIR)
        # Class-name list -> {'house': 3, 'flowers': 2, ...}
        class_names_to_ids = dict(zip(class_names, range(len(class_names))))
        # Reproducible shuffle, then split off the test set.
        random.seed(_RANDOM_SEED)
        random.shuffle(photo_filenames)
        training_filenames = photo_filenames[_NUM_TEST:]
        testing_filenames = photo_filenames[:_NUM_TEST]
        _convert_dataset('train', training_filenames,
                         class_names_to_ids, DATASET_DIR)
        _convert_dataset('test', testing_filenames,
                         class_names_to_ids, DATASET_DIR)
        # Inverse map {1: 'people', 2: 'flowers', ...} for the label file.
        labels_to_class_names = dict(zip(range(len(class_names)), class_names))
        write_label_file(labels_to_class_names, DATASET_DIR)
4. tf.argmax(data, axis)
获取最大元素的索引
-------------------------------------------keras--------------------------------------
1 . Keras---text.Tokenizer: https://blog.csdn.net/lovebyz/article/details/77712003