将目标检测的标注数据（.xml）转换为 TFRecord 格式，用于 TensorFlow 训练。
# Convert Pascal-VOC style detection annotations (.xml) and their .jpg images
# into a TFRecord file for TensorFlow training.
import os
import xml.etree.ElementTree as ET

import numpy as np

classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car",
           "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike",
           "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]

# Defaults used when this file is run as a script.
XML_DIR = 'F:/xml'
IMAGE_DIR = 'F:/snow leopard/test_im'
OUTPUT_FILE = 'test.tfrecords'
MAX_BOXES = 30           # every record is padded/truncated to this many boxes
IMAGE_SIZE = (416, 416)  # (width, height) as passed to PIL's Image.resize


def convert(size, box):
    """Convert a pixel-space box to normalised centre/size form.

    Args:
        size: (image_width, image_height) in pixels.
        box: (xmin, xmax, ymin, ymax) in pixels. Note this order, which is
            how convert_annotation builds it, not (xmin, ymin, xmax, ymax).

    Returns:
        [x_center, y_center, width, height], each divided by the image width
        or height respectively.
    """
    dw = 1.0 / size[0]
    dh = 1.0 / size[1]
    xmin, xmax, ymin, ymax = box
    return [(xmin + xmax) / 2.0 * dw,
            (ymin + ymax) / 2.0 * dh,
            (xmax - xmin) * dw,
            (ymax - ymin) * dh]


def convert_annotation(image_id, xml_dir=XML_DIR, max_boxes=MAX_BOXES):
    """Read ``<xml_dir>/<image_id>.xml`` and return a flat, padded box list.

    Each kept object contributes five values, in this order:
    x_center, y_center, width, height (normalised by convert), class index.
    Objects whose class is not in ``classes``, or that are marked difficult,
    are skipped. At most ``max_boxes`` objects are kept. The result is
    zero-padded to exactly ``max_boxes * 5`` values so that every record has
    the same length.

    Returns:
        A list of Python floats (rounded through float32).
    """
    # ET.parse opens and closes the file itself; the old open() call leaked it.
    root = ET.parse(os.path.join(xml_dir, '%s.xml' % image_id)).getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)

    bboxes = []
    for obj in root.iter('object'):
        # Limit the number of *kept* boxes. Counting every <object> would let
        # skipped difficult/unknown objects use up the budget.
        if len(bboxes) >= max_boxes * 5:
            break
        cls = obj.find('name').text.strip()
        difficult = obj.find('difficult')
        # <difficult> is optional in many labelling tools; absent/empty means 0.
        is_difficult = (difficult is not None
                        and int((difficult.text or '0').strip()) == 1)
        if cls not in classes or is_difficult:
            continue
        xmlbox = obj.find('bndbox')
        b = tuple(float(xmlbox.find(tag).text)
                  for tag in ('xmin', 'xmax', 'ymin', 'ymax'))
        bboxes.extend(convert((w, h), b) + [classes.index(cls)])

    # Zero-pad so every record has exactly max_boxes * 5 values.
    bboxes.extend([0] * (max_boxes * 5 - len(bboxes)))
    return np.array(bboxes, dtype=np.float32).tolist()


def convert_img(image_id, image_dir=IMAGE_DIR, size=IMAGE_SIZE):
    """Load ``<image_dir>/<image_id>.jpg`` and return its pixels as raw bytes.

    The image is converted to RGB, so every input yields 3 channels, and is
    then resized to ``size`` (width, height) with bicubic resampling. Pixel
    values are scaled to [0, 1]. The bytes are a headerless float32 array of
    shape (size[1], size[0], 3), so the reader must know this shape.
    """
    # Local import: the XML helpers above stay usable without Pillow installed.
    from PIL import Image

    with Image.open(os.path.join(image_dir, '%s.jpg' % image_id)) as image:
        # Without convert('RGB'), grayscale/CMYK/RGBA files would produce
        # buffers of a different length than RGB ones.
        resized = image.convert('RGB').resize(size, Image.BICUBIC)
    image_data = np.asarray(resized, dtype=np.float32) / 255
    return image_data.tobytes()


def main(image_dir=IMAGE_DIR, xml_dir=XML_DIR, output_file=OUTPUT_FILE):
    """Write one tf.train.Example per .jpg in ``image_dir`` to ``output_file``.

    Each example holds two features:
      'xywhc': FloatList from convert_annotation (MAX_BOXES * 5 values).
      'img':   BytesList holding the raw float32 pixels from convert_img.
    """
    # Local import: TensorFlow is only needed to write the records.
    import tensorflow as tf

    # Only .jpg files have a matching annotation; skip any other files in the
    # folder. splitext keeps names that contain dots intact, and sorting
    # makes the record order reproducible.
    image_ids = sorted(
        os.path.splitext(name)[0]
        for name in os.listdir(image_dir)
        if os.path.splitext(name)[1] == '.jpg'
    )

    # tf.io.TFRecordWriter replaces tf.python_io.TFRecordWriter, which was
    # removed in TF2. The with-block closes the file even on error.
    with tf.io.TFRecordWriter(output_file) as writer:
        for image_id in image_ids:
            print(image_id)
            xywhc = convert_annotation(image_id, xml_dir)
            img_raw = convert_img(image_id, image_dir)
            example = tf.train.Example(features=tf.train.Features(feature={
                'xywhc': tf.train.Feature(
                    float_list=tf.train.FloatList(value=xywhc)),
                'img': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[img_raw])),
            }))
            writer.write(example.SerializeToString())


if __name__ == '__main__':
    main()
Python 读取文件夹下图片的两种方法：
import os

# Names of every entry in ./images/ (all files, not only images; arbitrary order).
imagelist = os.listdir('./images/')
import glob

# Only files matching frame_*.png, sorted by name. This is usually preferable
# to os.listdir because it filters by pattern and gives a deterministic order.
# The sort is lexicographic (frame_10 sorts before frame_2), so zero-pad the
# frame numbers if numeric order matters.
imagelist = sorted(glob.glob('./images/frame_*.png'))
参考:
https://blog.csdn.net/CV_YOU/article/details/80778392
https://github.com/raytroop/YOLOv3_tf