• 制作 TFRecord 的代码——任意照片均可使用


    代码:

      1 # -*- coding: utf-8 -*-
      2 # @Time    : 2018/11/23 0:09
      3 # @Author  : MaochengHu
      4 # @Email   : wojiaohumaocheng@gmail.com
      5 # @File    : generate_tfrecord.py
      6 # @Software: PyCharm
      7 
      8 import os
      9 import tensorflow as tf
     10 import io
     11 from PIL import Image
     12 import json
     13 def get_annotation_dict(input_folder_path, word2number_dict):
     14     label_dict = {}
     15     father_file_list = os.listdir(input_folder_path)
     16     for father_file in father_file_list:
     17         full_father_file = os.path.join(input_folder_path, father_file)
     18         son_file_list = os.listdir(full_father_file)
     19         for image_name in son_file_list:
     20             label_dict[os.path.join(full_father_file, image_name)] = word2number_dict[father_file]
     21     return label_dict
     22 
     23 
     24 def save_json(label_dict, json_path):
     25     with open(json_path, 'w') as json_path:
     26         json.dump(label_dict, json_path)
     27     print("label json file has been generated successfully!")
     28 
     29 
     30 
     31 def int64_feature(value):
     32     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
     33 
     34 
     35 def bytes_feature(value):
     36     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
     37 
     38 
     39 def process_image_channels(image):
     40     process_flag = False
     41     # process the 4 channels .png
     42     if image.mode == 'RGBA':
     43         r, g, b, a = image.split()
     44         image = Image.merge("RGB", (r,g,b))
     45         process_flag = True
     46     # process the channel image
     47     elif image.mode != 'RGB':
     48         image = image.convert("RGB")
     49         process_flag = True
     50     return image, process_flag
     51 
     52 
     53 def process_image_reshape(image, resize):
     54     width, height = image.size
     55     if resize is not None:
     56         if width > height:
     57              width = int(width * resize / height)
     58              height = resize
     59         else:
     60             width = resize
     61             height = int(height * resize / width)
     62         image = image.resize((width, height), Image.ANTIALIAS)
     63     return image
     64 
     65 
     66 def create_tf_example(image_path, label, resize=None):
     67     with tf.gfile.GFile(image_path, 'rb') as fid:
     68         encode_jpg = fid.read()
     69     encode_jpg_io = io.BytesIO(encode_jpg)
     70     image = Image.open(encode_jpg_io)
     71     # process png pic with four channels
     72     image, process_flag = process_image_channels(image)
     73     # reshape image
     74     image = process_image_reshape(image, resize)
     75     if process_flag == True or resize is not None:
     76         bytes_io = io.BytesIO()
     77         image.save(bytes_io, format='JPEG')
     78         encoded_jpg = bytes_io.getvalue()
     79     width, height = image.size
     80     tf_example = tf.train.Example(
     81         features=tf.train.Features(
     82             feature={
     83                 'image/encoded': bytes_feature(encode_jpg),
     84                 'image/format': bytes_feature(b'jpg'),
     85                 'image/class/label': int64_feature(label),
     86                 'image/height': int64_feature(height),
     87                 'image/width': int64_feature(width)
     88             }
     89         ))
     90     return tf_example
     91 
     92 
     93 def generate_tfrecord(annotation_dict, record_path, resize=None):
     94     num_tf_example = 0
     95     writer = tf.io.TFRecordWriter(record_path)
     96     for image_path, label in annotation_dict.items():
     97         if not tf.gfile.GFile(image_path):
     98             print("{} does not exist".format(image_path))
     99         tf_example = create_tf_example(image_path, label, resize)
    100         writer.write(tf_example.SerializeToString())
    101         num_tf_example += 1
    102         if num_tf_example % 100 == 0:
    103             print("Create %d TF_Example" % num_tf_example)
    104     writer.close()
    105     print("{} tf_examples has been created successfully, which are saved in {}".format(num_tf_example, record_path))
    106 
    107 
    108 
    109 
    110 if __name__ == '__main__':
    111     word2number_dict = {
    112         "combinations": 0,
    113         "details": 1,
    114         "sizes": 2,
    115         "tags": 3,
    116         "models": 4,
    117         "tileds": 5,
    118         "hangs": 6
    119     }
    120     images_dir = '../images_root'
    121     #annotation_path = FLAGS.annotation_path
    122     record_path = 'train.record'
    123     annotation_dict = get_annotation_dict(images_dir, word2number_dict)
    124     print(annotation_dict)
    125     print("AAA")
    126     generate_tfrecord(annotation_dict, record_path)
  • 相关阅读:
    OC-KVO简介
    注册审核
    应用权限
    关于函数执行的一点知识
    设置权限
    文件操作实例:文件管理器(网页版)
    文件操作
    正则表达式
    全局变量和递归
    案例:简单留言板
  • 原文地址:https://www.cnblogs.com/smartisn/p/12438856.html
Copyright © 2020-2023  润新知