• tensorflow 相关


    1.从checkpoint中获取全部的变量名和变量值

    tf.contrib.framework.list_variables(model_dir)
    tf.contrib.framework.load_variable(model_dir, var_name)

    2.清除 tf.Session

    tf.reset_default_graph() 重置计算图

    3。 使用tf_record

    numpy数据可以直接制作dataset ds = tf.data.Dataset.from_tensor_slices(trg)

    正常情况的话

     1 a = np.random.randint(0,10,(10))
     2  2 b = np.random.rand(10,20)
     3  3 a1 = a.tobytes()
     4  4 b1 = b.tobytes()
     5  5 writer= tf.python_io.TFRecordWriter("./tfr/train.tfrecords")
     6  6 example = tf.train.Example(features=tf.train.Features(feature={
     7  7             "soft_targets": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b1])),
     8  8             'src_wids': tf.train.Feature(bytes_list=tf.train.BytesList(value=[a1]))
     9  9         }))
    10 10 for _ in range(10):
    11 11     writer.write(example.SerializeToString())
    12 12 writer.close()
    View Code
     1  def _parse_function(example_proto):
     2  2     features = tf.parse_single_example(
     3  3         example_proto,
     4  4         features={
     5  5             'src_wids': tf.FixedLenFeature([], tf.string),
     6  6             'soft_targets': tf.FixedLenFeature([], tf.string)
     7  7         }
     8  8     )
     9  9     # 取出我们需要的数据(标签,图片)
    10 10     label = features['soft_targets']
    11 11     feature = features['src_wids']
    12 12     label = tf.decode_raw(label, tf.float32)
    13 13     feature = tf.decode_raw(feature, tf.int64)
    14 14     return feature, label
    15 15 
    16 16 dataset = tf.contrib.data.TFRecordDataset("./tfr/train.tfrecords")
    17 17 dataset = dataset.map(_parse_function)
    18 18 dataset = dataset.batch(2)
    19 19 iterator = dataset.make_initializable_iterator()
    20 20 with tf.Session() as sess:
    21 21     sess.run(iterator.initializer)
    22 22     ids,trg = sess.run(iterator.get_next())
    View Code
      1     #Author:Anthony  
      2     #导入相应的模块  
      3     import tensorflow as tf  
      4     import os  
      5     import random  
      6     import math  
      7     import sys  
      8     #划分验证集训练集  
      9     _NUM_TEST = 40  
     10     #random seed  
     11     _RANDOM_SEED = 0  
     12     #数据块  
     13     _NUM_SHARDS = 2  
     14     #数据集路径  
     15     DATASET_DIR = '/home/anthony/文档/数据集_带标签/SHIYAN_SAMEZIZE'  
     16     #标签文件  
     17     LABELS_FILENAME = '/home/anthony/文档/数据集_带标签/SHIYAN_SAMEZIZE_labels.txt'  
     18     #定义tfrecord 的路径和名称  
     19     def _get_dataset_filename(dataset_dir,split_name,shard_id):  
     20         output_filename = 'image_%s_%05d-of-%05d.tfrecord' % (split_name,shard_id,_NUM_SHARDS)  
     21         return os.path.join(dataset_dir,output_filename)  
     22     #判断tfrecord文件是否存在  
     23     def _dataset_exists(dataset_dir):  
     24         for split_name in ['train','test']:  
     25             for shard_id in range(_NUM_SHARDS):  
     26                 #定义tfrecord的路径名字  
     27                 output_filename = _get_dataset_filename(dataset_dir,split_name,shard_id)  
     28             if not tf.gfile.Exists(output_filename):  
     29                 return False  
     30         return True  
     31     #获取图片以及分类  
     32     def _get_filenames_and_classes(dataset_dir):  
     33         #数据目录  
     34         directories = []  
     35         #分类名称  
     36         class_names = []  
     37         for filename in os.listdir(dataset_dir):  
     38             #合并文件路径  
     39             path = os.path.join(dataset_dir,filename)  
     40             #判断路径是否是目录  
     41             if os.path.isdir(path):  
     42                 #加入数据目录  
     43                 directories.append(path)  
     44                 #加入类别名称  
     45                 class_names.append(filename)  
     46         photo_filenames = []  
     47         #循环分类的文件夹  
     48         for directory in directories:  
     49             for filename in os.listdir(directory):  
     50                 path = os.path.join(directory,filename)  
     51                 #将图片加入图片列表中  
     52                 photo_filenames.append(path)  
     53         #返回结果  
     54         return photo_filenames ,class_names  
     55     def int64_feature(values):  
     56         if not isinstance(values,(tuple,list)):  
     57             values = [values]  
     58         return tf.train.Feature(int64_list=tf.train.Int64List(value=values))  
     59     def bytes_feature(values):  
     60         return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))  
     61     #图片转换城tfexample函数  
     62     def image_to_tfexample(image_data,image_format,class_id):  
     63         return tf.train.Example(features=tf.train.Features(feature={  
     64             'image/encoded': bytes_feature(image_data),  
     65             'image/format': bytes_feature(image_format),  
     66             'image/class/label': int64_feature(class_id)  
     67         }))  
     68     def write_label_file(labels_to_class_names,dataset_dir,filename=LABELS_FILENAME):  
     69         label_filename = os.path.join(dataset_dir,filename)  
     70         with tf.gfile.Open(label_filename,'w') as f:  
     71             for label in labels_to_class_names:  
     72                 class_name = labels_to_class_names[label]  
     73                 f.write('%d:%s
    ' % (label, class_name))  
     74     #数据转换城tfrecorad格式  
     75     def _convert_dataset(split_name,filenames,class_names_to_ids,dataset_dir):  
     76         assert split_name in ['train','test']  
     77         #计算每个数据块的大小  
     78         num_per_shard = int(len(filenames) / _NUM_SHARDS)  
     79         with tf.Graph().as_default():  
     80             with tf.Session() as sess:  
     81                 for shard_id in range(_NUM_SHARDS):  
     82                 #定义tfrecord的路径名字  
     83                     output_filename = _get_dataset_filename(dataset_dir,split_name,shard_id)  
     84                     with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:  
     85                         #每个数据块开始的位置  
     86                         start_ndx = shard_id * num_per_shard  
     87                         #每个数据块结束的位置  
     88                         end_ndx = min((shard_id+1) * num_per_shard,len(filenames))  
     89                         for i in range(start_ndx,end_ndx):  
     90                             try:  
     91                                 sys.stdout.write('
    >> Converting image %d/%d shard %d '% (i+1,len(filenames),shard_id))  
     92                                 sys.stdout.flush()  
     93                                 #读取图片  
     94                                 image_data = tf.gfile.FastGFile(filenames[i],'rb').read()  
     95                                 #获取图片的类别名称  
     96                                 #basename获取图片路径最后一个字符串  
     97                                 #dirname是除了basename之外的前面的字符串路径r  
     98                                 class_name = os.path.basename(os.path.dirname(filenames[i]))  
     99                                 #获取图片的id  
    100                                 class_id = class_names_to_ids[class_name]  
    101                                 #生成tfrecord文件  
    102                                 example = image_to_tfexample(image_data,b'jpg',class_id)  
    103                                 #写入数据  
    104                                 tfrecord_writer.write(example.SerializeToString())  
    105                             except IOError  as e:  
    106                                 print ('could not read:',filenames[1])  
    107                                 print ('error:' , e)  
    108                                 print ('skip it 
    ')  
    109         sys.stdout.write('
    ')  
    110         sys.stdout.flush()  
    111       
    112     if __name__ == '__main__':  
    113         #判断tfrecord文件是否存在  
    114         if _dataset_exists(DATASET_DIR):  
    115             print ('tfrecord exists')  
    116         else:  
    117             #获取图片以及分类  
    118             photo_filenames,class_names = _get_filenames_and_classes(DATASET_DIR)  
    119             #将分类的list转换成dictionary{‘house':3,'flowers:2'}  
    120             class_names_to_ids = dict(zip(class_names,range(len(class_names))))  
    121             #切分数据为测试训练集  
    122             random.seed(_RANDOM_SEED)  
    123             random.shuffle(photo_filenames)  
    124             training_filenames = photo_filenames[_NUM_TEST:]  
    125             testing_filenames = photo_filenames[:_NUM_TEST]  
    126             #数据转换  
    127             _convert_dataset('train',training_filenames,class_names_to_ids,DATASET_DIR)  
    128             _convert_dataset('test',testing_filenames,class_names_to_ids,DATASET_DIR)  
    129             #输出lables文件  
    130             #与前面的 class_names_to_ids中的元素位置相反{1:'people,2:'flowers'}  
    131             labels_to_class_names = dict(zip(range(len(class_names)),class_names))  
    132             write_label_file(labels_to_class_names,DATASET_DIR)
    View Code

     4.tf,argmax(data, axis)

    获取最大元素的索引

    -------------------------------------------keras--------------------------------------

    1 . Keras---text.Tokenizer: https://blog.csdn.net/lovebyz/article/details/77712003

    
    
    
    
    
  • 相关阅读:
    【SSH网上商城项目实战05】完成数据库的级联查询和分页
    后台dubug有值且sql也打印出来执行了但是前台就是查不到数据
    Caused by: java.lang.NoSuchMethodError: javax.persistence.JoinColumn.foreignKey()Ljavax/persistence/
    异常:Caused by: java.lang.NoSuchMethodError: javax.persistence.OneToMany.orphanRemoval()Z/Caused by: java.lang.NoSuchMethodError: javax.persistence.JoinColumn.foreign
    @Resource或者@Autowired作用/Spring中@Autowired注解、@Resource注解的区别
    【SSH网上商城项目实战04】EasyUI菜单的实现
    【SSH网上商城项目实战03】使用EasyUI搭建后台页面框架
    【SSH网上商城项目实战02】基本增删查改、Service和Action的抽取以及使用注解替换xml
    【SSH网上商城项目实战01】整合Struts2、Hibernate4.3和Spring4.2
    【SpringMVC学习01】宏观上把握SpringMVC框架
  • 原文地址:https://www.cnblogs.com/wb-learn/p/11596980.html
Copyright © 2020-2023  润新知