  • labelImg 标注数据转 TFRecord 数据


    一、labelImg 的安装和使用方法请自行搜索,这里不再赘述。

    二、xml 转 csv

    用 labelImg 标注好图片后,会得到 N 个 xml 文件(Pascal VOC 格式,每张图片对应一个);这里我们遍历 xml 目录,把所有标注框汇总成一个 csv 文件(每个标注框一行)。

    import glob
    import pandas as pd
    import xml.etree.ElementTree as ET
    
    
    def xml_to_csv(xml_dir):
        """Collect all labelImg (Pascal VOC) ``.xml`` files in *xml_dir* into one DataFrame.

        Every ``<object>`` element of every xml file becomes one row. Only the
        top level of *xml_dir* is scanned (no recursion). Each file path is
        printed as it is processed.

        Args:
            xml_dir: directory containing the ``.xml`` annotation files.

        Returns:
            pandas.DataFrame with columns ``filename, width, height, class,
            xmin, ymin, xmax, ymax`` (one row per bounding box, integer
            coordinates). Empty (with those columns) if no boxes are found.
        """
        xml_list = []
        # glob returns files in arbitrary order; sort so the CSV row order is
        # reproducible between runs.
        for xml_file in sorted(glob.glob(xml_dir + '/*.xml')):
            print(xml_file)
            root = ET.parse(xml_file).getroot()
            objects = root.findall('object')
            if not objects:
                continue  # nothing annotated in this image
            filename = root.find('filename').text
            size = root.find('size')
            width = int(size.find('width').text)
            height = int(size.find('height').text)
            for member in objects:
                # Look children up by tag name instead of position: the child
                # order of <object> may differ between tools/versions, and
                # positional indexing would then silently read the wrong element.
                bndbox = member.find('bndbox')
                # int(float(...)) also accepts coordinates written as "12.0".
                xmin, ymin, xmax, ymax = (
                    int(float(bndbox.find(tag).text))
                    for tag in ('xmin', 'ymin', 'xmax', 'ymax'))
                xml_list.append((filename, width, height,
                                 member.find('name').text,
                                 xmin, ymin, xmax, ymax))

        column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
        xml_df = pd.DataFrame(xml_list, columns=column_name)
        return xml_df
    
    
    if __name__ == '__main__':
        # Input: directory holding the .xml files produced by labelImg.
        xml_dir = r'/path_to_xml'
        xml_df = xml_to_csv(xml_dir)
        # Output: where the generated .csv file is written.
        csv_path = r'/path_to_output/xxx.csv'
        # pandas documents `index` as a bool; index=None only worked because
        # None is falsy. False keeps the row-number column out of the CSV.
        xml_df.to_csv(csv_path, index=False)
        print('Successfully converted xml to csv')
        print(csv_path)

    三、csv 转 tfrecord

    注:需要先安装好 TensorFlow Object Detection API(即 object_detection 包),安装教程:https://www.cnblogs.com/tujia/p/13952108.html

    import os
    import json
    import pandas as pd
    from object_detection.dataset_tools import create_coco_tf_record
    
    
    def _csv_to_coco_groundtruth(examples, category_index):
        """Build a COCO-style ground-truth dict (images/annotations/categories) from a box DataFrame."""
        categories = list(category_index.values())
        name_to_id = {category['name']: category['id'] for category in categories}
        # With a single category every box gets that id whatever its label
        # text says (this is what the previously hard-coded id 1 did).
        single_id = categories[0]['id'] if len(categories) == 1 else None

        images, annotations = [], []
        image_ids = {}  # filename -> image id (dict: O(1) lookup, order-independent)
        for i, row in examples.iterrows():
            filename = row['filename']
            if filename not in image_ids:
                image_ids[filename] = int(i)
                images.append({
                    'id': int(i),
                    'file_name': filename,
                    'width': int(row['width']),
                    'height': int(row['height']),
                })
            xmin, ymin = int(row['xmin']), int(row['ymin'])
            box_width = int(row['xmax']) - xmin
            box_height = int(row['ymax']) - ymin
            annotations.append({
                'area': float(box_width * box_height),
                'iscrowd': False,
                # Resolved by file name, so rows of one image need not be adjacent.
                'image_id': image_ids[filename],
                'bbox': [xmin, ymin, box_width, box_height],  # COCO: x, y, w, h
                'category_id': single_id if single_id is not None else name_to_id[row['class']],
                'id': int(i),
            })
        return {'images': images, 'annotations': annotations, 'categories': categories}


    def create_tfrecord(csv_path, data_dir, output_dir, class_name=None,
                        category_index=None, num_shards=2):
        """Convert the box CSV from ``xml_to_csv`` into sharded TFRecord files.

        The CSV is first written out as a COCO-style ground-truth JSON file
        ``<output_dir>/<class_name>_annotation.json``, which is then passed to
        ``create_coco_tf_record._create_tf_record_from_coco_annotations``.

        Args:
            csv_path: CSV with columns filename, width, height, class,
                xmin, ymin, xmax, ymax (one row per box).
            data_dir: directory containing the images named in the CSV.
            output_dir: directory for the JSON file and the TFRecord output.
            class_name: prefix of the output file names. Defaults to the CSV
                file name without extension (``/a/foo.csv`` -> ``foo``).
            category_index: ``{id: {'id': id, 'name': name}}``. Defaults to the
                single category ``{1: {'id': 1, 'name': class_name}}``.
            num_shards: number of TFRecord shards to write.

        Raises:
            KeyError: with more than one category, if a CSV ``class`` value
                matches no category name.
        """
        # Previously these came from module globals that only exist when the
        # file is run as a script; derive them from the arguments instead.
        if class_name is None:
            class_name = os.path.splitext(os.path.basename(csv_path))[0]
        if category_index is None:
            category_index = {1: {'id': 1, 'name': class_name}}

        groundtruth_data = _csv_to_coco_groundtruth(pd.read_csv(csv_path), category_index)
        annotation_file = os.path.join(output_dir, class_name + '_annotation.json')
        with open(annotation_file, 'w') as annotation_fid:
            json.dump(groundtruth_data, annotation_fid)

        output_path = os.path.join(output_dir, class_name + '.record')
        create_coco_tf_record._create_tf_record_from_coco_annotations(
            annotation_file,
            data_dir,
            output_path,
            False,  # include_masks
            num_shards)
        print('Finish!!')
        # Files are written as <output_path>-NNNNN-of-<num_shards>.
        print('%s-?????-of-%05d' % (output_path, num_shards))
    
    
    if __name__ == '__main__':
        # Single class, so the category index holds exactly one entry (id 1).
        class_name = 'xxx'
        category_index = {1: dict(id=1, name=class_name)}
        # Inputs: the CSV produced in step 2 and the folder with its images.
        CSV_PATH = f'/tf/datasets/{class_name}.csv'
        DATA_DIR = f'/tf/datasets/{class_name}'
        # Output: directory receiving the annotation JSON and the TFRecord files.
        OUTPUT_DIR = '/tf/object_detection/data/'
        create_tfrecord(CSV_PATH, DATA_DIR, OUTPUT_DIR)

    注:我这里只有一个类,所以 category_index 直接写死了;如果有多个类,请自行修改 category_index 等相关代码。

    四、验证 tfrecord 数据准确性(可视化)

    import tensorflow as tf
    import numpy as np
    import matplotlib.pyplot as plt
    from object_detection.utils import visualization_utils as viz_utils
    from six import BytesIO
    from PIL import Image
    
    %matplotlib inline
    
    
    def load_image_into_numpy_array(img_data):
        """Decode encoded image bytes into an RGB ``uint8`` numpy array.

        Args:
            img_data: raw encoded image bytes, e.g. the ``image/encoded``
                feature value of a TFRecord example.

        Returns:
            numpy array of shape (height, width, 3), dtype uint8.
        """
        image = Image.open(BytesIO(img_data))
        # convert('RGB') turns grayscale, palette and RGBA images into
        # 3 channels. The previous getdata() + reshape((h, w, 3)) raised
        # ValueError for them and built a slow Python list of pixel tuples.
        return np.array(image.convert('RGB'), dtype=np.uint8)
    
    
    def plot_detections(image_np,
                        boxes,
                        classes,
                        scores,
                        category_index,
                        figsize=(12, 16),
                        image_name=None,
                        min_score_thresh=0.8):
        """Draw boxes on a copy of *image_np*, then display it or save it to a file.

        Args:
            image_np: image array; a copy is annotated, the input is untouched.
            boxes: box array in normalized coordinates (this function passes
                use_normalized_coordinates=True), e.g. the (N, 4)
                [ymin, xmin, ymax, xmax] arrays built by get_boxes.
            classes: class ids per box (keys of *category_index*).
            scores: score per box.
            category_index: ``{id: {'id': id, 'name': name}}`` for box labels.
            figsize: matplotlib figure size (inches) used when displaying.
            image_name: if given, save the annotated image to this path
                instead of displaying it.
            min_score_thresh: forwarded to viz_utils as its score threshold
                (was hard-coded to 0.8, which is still the default).
        """
        image_np_with_annotations = image_np.copy()
        viz_utils.visualize_boxes_and_labels_on_image_array(
            image_np_with_annotations,
            boxes,
            classes,
            scores,
            category_index,
            use_normalized_coordinates=True,
            min_score_thresh=min_score_thresh)
        if image_name:
            plt.imsave(image_name, image_np_with_annotations)
        else:
            # figsize used to be accepted but silently ignored.
            plt.figure(figsize=figsize)
            plt.imshow(image_np_with_annotations)
    
    def get_boxes(filenames, num_records=2):
        """Read the first *num_records* examples from TFRecord file(s).

        Args:
            filenames: TFRecord path or list of paths, passed to
                tf.data.TFRecordDataset (read in the given order).
            num_records: how many examples to read (previously fixed at 2).

        Returns:
            (images_np, gt_boxes): two lists with one entry per example. The
            first holds the decoded RGB image arrays. The second holds float32
            arrays of shape (N, 4), one row [ymin, xmin, ymax, xmax] per box
            of that example (N may be 0).
        """
        images_np = []
        gt_boxes = []
        for raw_record in tf.data.TFRecordDataset(filenames).take(num_records):
            example = tf.train.Example()
            example.ParseFromString(raw_record.numpy())
            feature = example.features.feature
            images_np.append(load_image_into_numpy_array(
                feature['image/encoded'].bytes_list.value[0]))
            # Take every box of the image, not just value[0]. Multi-object
            # images are then fully checked, and images without boxes no
            # longer raise IndexError.
            coords = [list(feature['image/object/bbox/' + key].float_list.value)
                      for key in ('ymin', 'xmin', 'ymax', 'xmax')]
            gt_boxes.append(np.array(coords, dtype=np.float32).T.reshape(-1, 4))
        return images_np, gt_boxes


    if __name__ == '__main__':
        class_name = 'xxx'
        category_index = {1: {'id': 1, 'name': class_name}}
        # The csv -> tfrecord step writes sharded files named
        # '<name>.record-NNNNN-of-MMMMM' (2 shards there), so a hard-coded
        # '-00000-of-00001' name does not exist. Match every shard instead.
        filenames = sorted(tf.io.gfile.glob(
            '/tf/object_detection/data/%s.record-*' % class_name))
        if not filenames:
            raise FileNotFoundError('no TFRecord shards found for %s' % class_name)
        (images_np, gt_boxes) = get_boxes(filenames)

        # zip: there may be fewer records than requested.
        for image_np, boxes in zip(images_np, gt_boxes):
            num_boxes = boxes.shape[0]
            # Ground truth: class id 1 and a 100% score for every box. Both
            # arrays are sized to the box count so multi-box images do not
            # index past the end of the scores.
            plot_detections(
                image_np,
                boxes,
                np.ones(shape=[num_boxes], dtype=np.int32),
                np.ones(shape=[num_boxes], dtype=np.float32),
                category_index)
  • 相关阅读:
    后台java,前台extjs文件下载
    gridPanel可拖拽排序
    Extjs 获取输入框焦点,并选中值
    java poi对Excel文件加密
    java poi 读取有密码加密的Excel文件
    SSL 与 数字证书 的基本概念和工作原理
    splay树
    树剖版lca
    树链剖分
    kruskal重构树
  • 原文地址:https://www.cnblogs.com/tujia/p/14085045.html
Copyright © 2020-2023  润新知