• 自己写的 Cityscapes 语义分割 TFRecord 制作脚本,适用于 DeepLabv3+


    自己写的 Cityscapes 语义分割 TFRecord 制作脚本,适用于 DeepLabv3+

    自用

    """Converts PASCAL dataset to TFRecords file format."""
    
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    
    import argparse
    import io
    import os
    import sys
    import natsort
    import PIL.Image
    import tensorflow as tf
    
    from utils import dataset_util
    
    # Command-line options. Parsed in the __main__ guard into the global FLAGS,
    # which main() reads.
    parser = argparse.ArgumentParser()

    parser.add_argument('--data_dir', type=str, default='/home/a/dataset/cityscapes/',
                        help='Path to the Cityscapes root directory (the directory '
                             'containing --image_data_dir and --label_data_dir).')

    parser.add_argument('--output_path', type=str, default='./dataset',
                        help='Path to the directory to create TFRecords outputs.')

    # NOTE: the two *_data_list options are not read by this script; images and
    # labels are discovered by walking the train/val directories instead.
    parser.add_argument('--train_data_list', type=str, default='./dataset/train.txt',
                        help='Path to the file listing the training data '
                             '(currently unused by this script).')

    parser.add_argument('--valid_data_list', type=str, default='./dataset/val.txt',
                        help='Path to the file listing the validation data '
                             '(currently unused by this script).')

    parser.add_argument('--image_data_dir', type=str, default='leftImg8bit',
                        help='The directory under --data_dir containing the image '
                             'data (with train/ and val/ sub-directories).')

    parser.add_argument('--label_data_dir', type=str, default='gtFine',
                        help='The directory under --data_dir containing the label '
                             'data (with train/ and val/ sub-directories); only '
                             '*_labelTrainIds files are used.')
    
    
    def dict_to_tf_example(image_path,
                           label_path):
      """Pack one image/label PNG pair into a tf.train.Example.

      Both files are read as raw bytes and stored undecoded under
      'image/encoded' and 'label/encoded'; PIL is only used to check the
      container format and to read the dimensions.

      Args:
        image_path: Path to the input image (must be a PNG).
        label_path: Path to the matching label image (must be a PNG).

      Returns:
        A tf.train.Example with the keys image/height, image/width,
        image/encoded, image/format, label/encoded and label/format.

      Raises:
        ValueError: if the image is not a PNG, if the label is not a PNG, or
          if the image and label sizes differ.
      """
      def _load(path):
        # Return the raw file bytes plus a PIL handle opened on those bytes.
        with tf.gfile.GFile(path, 'rb') as handle:
          raw = handle.read()
        return raw, PIL.Image.open(io.BytesIO(raw))

      image_bytes, image = _load(image_path)
      if image.format != 'PNG':
        raise ValueError('Image format not PNG')

      label_bytes, label = _load(label_path)
      if label.format != 'PNG':
        raise ValueError('Label format not PNG')

      if image.size != label.size:
        raise ValueError('The size of image does not match with that of label.')

      # PIL reports size as (width, height).
      width, height = image.size
      png_tag = 'png'.encode('utf8')

      feature = {
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/encoded': dataset_util.bytes_feature(image_bytes),
        'image/format': dataset_util.bytes_feature(png_tag),
        'label/encoded': dataset_util.bytes_feature(label_bytes),
        'label/format': dataset_util.bytes_feature(png_tag),
      }
      return tf.train.Example(features=tf.train.Features(feature=feature))
    def scanDir_img_File(dir):
        """Yield the path of every file under ``dir``, walking it recursively.

        The walk is top-down, ignores os.walk errors and does not follow
        symlinked directories. No extension filter is applied: every file
        found is yielded.
        """
        walker = os.walk(dir, topdown=True, onerror=None, followlinks=False)
        for folder, _subdirs, names in walker:
            for name in names:
                yield os.path.join(folder, name)
    
    def scanDir_lable_File(dir):
        """Yield paths of ``*_labelTrainIds`` label files under ``dir``.

        Walks ``dir`` recursively (top-down, os.walk errors ignored, symlinked
        directories not followed). A file is yielded only if its name without
        extension ends in ``_labelTrainIds``, e.g.
        ``aachen_000000_000019_gtFine_labelTrainIds.png``; all other files are
        skipped.
        """
        for root, _dirs, files in os.walk(dir, True, None, False):
            for f in files:
                path = os.path.join(root, f)
                if not os.path.isfile(path):
                    continue
                stem = os.path.splitext(f)[0]
                # Exact suffix match. The previous check,
                # `stem.split('_')[4] in ('labelTrainIds')`, tested substring
                # membership in a plain string (so 'label', 'Ids', ... matched).
                # It also raised IndexError for names with fewer than five
                # '_'-separated parts.
                if stem.endswith('_labelTrainIds'):
                    yield path
    
    def create_tf_record(output_filename,
                         image_dir,
                         label_dir):
      """Creates a TFRecord file from the image/label pairs found on disk.

      Every file under ``image_dir`` (scanDir_img_File) and every
      ``*_labelTrainIds`` file under ``label_dir`` (scanDir_lable_File) is
      collected. Both lists are natural-sorted and then paired by position, so
      the i-th image is assumed to match the i-th label. This relies on both
      trees listing their files in the same order. Each pair is printed and
      then written as one serialized tf.train.Example. Pairs that
      dict_to_tf_example rejects with ValueError are logged and skipped.

      Args:
        output_filename: Path to where output file is saved.
        image_dir: Directory where image files are stored.
        label_dir: Directory where label files are stored.
      """
      image_list = natsort.natsorted(scanDir_img_File(image_dir))
      label_list = natsort.natsorted(scanDir_lable_File(label_dir))

      if len(image_list) != len(label_list):
        # zip() below stops at the shorter list; report the mismatch rather
        # than silently dropping examples.
        tf.logging.warning('Found %d images in %s but %d labels in %s; '
                           'unpaired entries will be ignored.',
                           len(image_list), image_dir,
                           len(label_list), label_dir)

      # The context manager closes the writer even if an unexpected
      # (non-ValueError) exception escapes the loop.
      with tf.python_io.TFRecordWriter(output_filename) as writer:
        for image_path, label_path in zip(image_list, label_list):
          print(image_path, label_path)
          try:
            tf_example = dict_to_tf_example(image_path, label_path)
          except ValueError as err:
            tf.logging.warning('Invalid example: %s / %s (%s), ignoring.',
                               image_path, label_path, err)
            continue
          writer.write(tf_example.SerializeToString())
    
    
    def main(unused_argv):
      """Write city_train.record and city_val.record from the Cityscapes splits.

      All paths come from the module-level FLAGS namespace. For each split,
      images are taken from <data_dir>/<image_data_dir>/<split> and labels from
      <data_dir>/<label_data_dir>/<split>. The output directory is created if
      it does not exist.

      Args:
        unused_argv: Ignored; passed in by tf.app.run.
      """
      out_dir = FLAGS.output_path
      if not os.path.exists(out_dir):
        os.makedirs(out_dir)

      tf.logging.info("Reading from CITY dataset")
      splits = (('train', 'city_train.record'),
                ('val', 'city_val.record'))
      for split, record_name in splits:
        create_tf_record(
            os.path.join(out_dir, record_name),
            os.path.join(FLAGS.data_dir, FLAGS.image_data_dir, split),
            os.path.join(FLAGS.data_dir, FLAGS.label_data_dir, split))
    
    
    if __name__ == '__main__':
      # Show INFO-level messages from tf.logging (progress, warnings).
      tf.logging.set_verbosity(tf.logging.INFO)
      # Options this parser knows go into the global FLAGS read by main();
      # anything unrecognised is forwarded to tf.app.run after the program name.
      FLAGS, unparsed = parser.parse_known_args()
      forwarded_argv = [sys.argv[0]]
      forwarded_argv.extend(unparsed)
      tf.app.run(main=main, argv=forwarded_argv)
     
  • 相关阅读:
    延迟加载时发生no session错误的解决办法
    零零散散的一些知识点(一)
    零零散散的一些知识点(二)
    自己写的一个日历表
    js复制网址
    load方法在延迟加载时可能出现的错误。
    JSON基本介绍
    JBOSS4.0 JDBC数据源配置大全
    EJB学习笔记一
    Android程序完全退出的方法
  • 原文地址:https://www.cnblogs.com/ansang/p/8631857.html
Copyright © 2020-2023  润新知