• 对小样本进行数据增强


    针对 YOLOv3 训练数据不足的情况,可以通过数据增强来扩充样本;对图像做增强的同时,需要同步更新原始标注(XML)中的边界框坐标。

      1 import xml.etree.ElementTree as ET
      2 import os
      3 import numpy as np
      4 from PIL import Image
      5 import shutil
      6 
      7 import imgaug as ia
      8 from imgaug import augmenters as iaa
      9 
     10 
     11 ia.seed(1)
     12 
     def read_xml_annotation(root, image_id):
         """Read every bounding box stored in one XML annotation file.

         The file is expected to hold ``<object>`` elements directly under the
         document root, each with a ``<bndbox>`` child that carries integer
         ``xmin``/``ymin``/``xmax``/``ymax`` values.

         Args:
             root: Directory that contains the annotation file.
             image_id: File name of the annotation, including its extension.

         Returns:
             A list of ``[xmin, ymin, xmax, ymax]`` integer lists, one per
             ``<object>`` element in document order. The list is empty when the
             file contains no ``<object>`` element.

         Raises:
             AttributeError: An ``<object>`` lacks ``<bndbox>`` or a coordinate.
             ValueError: A coordinate is not an integer literal.
         """
         # Parsing by path lets ElementTree open, decode and close the file
         # itself, so no handle is leaked.
         tree = ET.parse(os.path.join(root, image_id))

         bndboxlist = []
         for obj in tree.getroot().findall('object'):
             bndbox = obj.find('bndbox')
             xmin = int(bndbox.find('xmin').text)
             ymin = int(bndbox.find('ymin').text)
             xmax = int(bndbox.find('xmax').text)
             ymax = int(bndbox.find('ymax').text)
             bndboxlist.append([xmin, ymin, xmax, ymax])

         return bndboxlist
     32 
     33 
     34 # (506.0000, 330.0000, 528.0000, 348.0000) -> (520.4747, 381.5080, 540.5596, 398.6603)
     def change_xml_annotation(root, image_id, new_target, new_id=None):
         """Replace the first object's bounding box in an XML annotation.

         Only the first ``<object>`` element of the file is modified.

         Args:
             root: Directory that holds the annotation; the result is written
                 to this directory as well.
             image_id: Annotation file name without the ``.xml`` extension.
             new_target: Sequence ``(xmin, ymin, xmax, ymax)`` with the new box.
             new_id: Optional integer id. When given, the result is saved as
                 the id zero-padded to six digits plus ``.xml``. When omitted,
                 ``<image_id>.xml`` is rewritten in place.

         Raises:
             AttributeError: The file has no ``<object>``/``<bndbox>`` element.
         """
         # Parsing by path lets ElementTree open and close the file itself.
         tree = ET.parse(os.path.join(root, str(image_id) + '.xml'))
         bndbox = tree.getroot().find('object').find('bndbox')

         for position, tag in enumerate(('xmin', 'ymin', 'xmax', 'ymax')):
             bndbox.find(tag).text = str(new_target[position])

         # The previous code applied "%06d" to a string built from the builtin
         # ``id`` function, which raised TypeError on every call.
         if new_id is None:
             out_name = str(image_id) + '.xml'
         else:
             out_name = "%06d" % int(new_id) + '.xml'
         tree.write(os.path.join(root, out_name))
     55 
     56 
     def change_xml_list_annotation(root, image_id, new_target, saveroot, id):
         """Write a copy of an XML annotation with every bounding box replaced.

         The ``<filename>`` element (when present) is set to the new id
         zero-padded to six digits plus ``.bmp``, and the i-th ``<object>``
         receives the coordinates ``new_target[i]``.

         Args:
             root: Directory that holds the source annotation.
             image_id: Source annotation file name without ``.xml``.
             new_target: List of ``[xmin, ymin, xmax, ymax]`` entries, ordered
                 like the ``<object>`` elements of the source file.
             saveroot: Directory the new annotation is written to.
             id: Integer (or integer-like) id used to name the output file.

         Raises:
             IndexError: ``new_target`` has fewer entries than the file has
                 ``<object>`` elements, or an entry has fewer than four values.
         """
         # Parsing by path lets ElementTree open and close the file itself.
         tree = ET.parse(os.path.join(root, str(image_id) + '.xml'))
         new_stem = "%06d" % int(id)

         # Annotations without a <filename> element are still rewritten.
         filename_elem = tree.find('filename')
         if filename_elem is not None:
             filename_elem.text = new_stem + '.bmp'

         coord_tags = ('xmin', 'ymin', 'xmax', 'ymax')
         for index, obj in enumerate(tree.getroot().findall('object')):
             bndbox = obj.find('bndbox')
             new_box = new_target[index]
             for position, tag in enumerate(coord_tags):
                 bndbox.find(tag).text = str(new_box[position])

         tree.write(os.path.join(saveroot, new_stem + '.xml'))
     90 
     91 
     def mkdir(path):
         """Create directory ``path`` (with parents) unless it already exists.

         Surrounding whitespace and any trailing ``/`` characters are removed
         from ``path`` before it is used. A status line is printed either way.

         Returns:
             True when the directory was created, False when the path already
             existed.
         """
         target = path.strip().rstrip("/")

         # Guard clause: an existing path is reported and left untouched.
         if os.path.exists(target):
             print(target + ' 目录已存在')
             return False

         os.makedirs(target)
         print(target + ' 创建成功')
         return True
    112 
    113 
    if __name__ == "__main__":
        # Script entry point. For every XML annotation in XML_DIR it copies the
        # original annotation/image pair into the output folders, then writes
        # AUGLOOP augmented image + annotation pairs next to them.

        IMG_DIR = "E:/codePro/python_pro/Yolo-Faster-XL_QR_pro/QR_Data_Augmentation/one/images/"
        XML_DIR = "E:/codePro/python_pro/Yolo-Faster-XL_QR_pro/QR_Data_Augmentation/one/Anotations/"

        # Output folder for the augmented XML files; recreated on every run.
        AUG_XML_DIR = "E:/codePro/python_pro/Yolo-Faster-XL_QR_pro/QR_Data_Augmentation/one/Anotations_Augment/"
        try:
            shutil.rmtree(AUG_XML_DIR)
        except FileNotFoundError:
            pass  # Nothing to remove on the first run.
        mkdir(AUG_XML_DIR)

        # Output folder for the augmented images; recreated on every run.
        AUG_IMG_DIR = "E:/codePro/python_pro/Yolo-Faster-XL_QR_pro/QR_Data_Augmentation/one/images_Augment/"
        try:
            shutil.rmtree(AUG_IMG_DIR)
        except FileNotFoundError:
            pass  # Nothing to remove on the first run.
        mkdir(AUG_IMG_DIR)

        AUGLOOP = 10  # Number of augmented copies generated per image.

        # Augmentation pipeline applied to both the image and its boxes.
        seq = iaa.Sequential([
            iaa.Flipud(0.5),  # Flip vertically with probability 0.5.
            iaa.Fliplr(0.5),  # Mirror horizontally with probability 0.5.
            iaa.Multiply((1.2, 1.5)),  # Change brightness; boxes are unaffected.
            iaa.GaussianBlur(sigma=(0, 2.0)),
            iaa.Affine(
                translate_px={"x": 15, "y": 15},
                scale=(0.8, 0.95),
                rotate=(-30, 30)
            )  # Shift 15 px on x/y, scale to 80-95 %, rotate within +/-30 deg.
        ])

        for root, sub_folders, files in os.walk(XML_DIR):

            nameCnt = 0

            for name in files:
                # Anything that is not an annotation would crash the parser.
                if not name.lower().endswith('.xml'):
                    continue

                stem = name[:-4]
                img_path = os.path.join(IMG_DIR, stem + '.bmp')

                bndbox = read_xml_annotation(XML_DIR, name)
                shutil.copy(os.path.join(XML_DIR, name), AUG_XML_DIR)
                shutil.copy(img_path, AUG_IMG_DIR)

                # The source image is identical for every epoch, so it is
                # loaded once and the file handle is released right away.
                with Image.open(img_path) as pil_img:
                    img = np.array(pil_img)

                for epoch in range(AUGLOOP):
                    # A deterministic copy applies the same sampled transform
                    # to the boxes and to the image.
                    seq_det = seq.to_deterministic()

                    # One [x1, y1, x2, y2] entry per box, in annotation order.
                    new_bndbox_list = []
                    for x1, y1, x2, y2 in bndbox:
                        bbs = ia.BoundingBoxesOnImage([
                            ia.BoundingBox(x1=x1, y1=y1, x2=x2, y2=y2),
                        ], shape=img.shape)
                        box_aug = seq_det.augment_bounding_boxes([bbs])[0].bounding_boxes[0]

                        # Clamp the transformed corners into the image area.
                        n_x1 = int(max(1, min(img.shape[1], box_aug.x1)))
                        n_y1 = int(max(1, min(img.shape[0], box_aug.y1)))
                        n_x2 = int(max(1, min(img.shape[1], box_aug.x2)))
                        n_y2 = int(max(1, min(img.shape[0], box_aug.y2)))
                        # Keep a box collapsed onto the left/top edge at least
                        # one pixel wide/high.
                        if n_x1 == 1 and n_x1 == n_x2:
                            n_x2 += 1
                        if n_y1 == 1 and n_y2 == n_y1:
                            n_y2 += 1
                        if n_x1 >= n_x2 or n_y1 >= n_y2:
                            print('error', name)
                        new_bndbox_list.append([n_x1, n_y1, n_x2, n_y2])

                    # NOTE(review): nameCnt advances once per epoch, not once
                    # per file; the numbering is kept as it was.
                    aug_id = len(files) + nameCnt + epoch * 250
                    aug_stem = "%06d" % aug_id

                    # Save the augmented image directly. The previous code
                    # first passed it through draw_on_image(thickness=0) on the
                    # last un-augmented box, which is undefined or stale when
                    # an annotation has no objects.
                    # NOTE(review): zero thickness is expected to draw nothing;
                    # confirm against the installed imgaug version.
                    image_aug = seq_det.augment_images([img])[0]
                    Image.fromarray(image_aug).save(os.path.join(AUG_IMG_DIR, aug_stem + '.bmp'))

                    # Save the matching annotation under the same id.
                    change_xml_list_annotation(XML_DIR, stem, new_bndbox_list, AUG_XML_DIR,
                                               aug_id)
                    print(aug_stem + '.bmp')

                    nameCnt += 1
    View Code

      参考:https://blog.csdn.net/weixin_45829462/article/details/105951949

  • 相关阅读:
    PAT 1006 Sign In and Sign Out
    PAT 1004. Counting Leaves
    JavaEE开发环境安装
    NoSql数据库探讨
    maven的配置
    VMWARE 下使用 32位 Ubuntu Linux ,不能给它分配超过3.5G 内存?
    XCODE 4.3 WITH NO GCC?
    在苹果虚拟机上跑 ROR —— Ruby on Rails On Vmware OSX 10.7.3
    推荐一首让人疯狂的好歌《Pumped Up Kicks》。好吧,顺便测下博客园可以写点无关技术的帖子吗?
    RUBY元编程学习之”编写你的第一种领域专属语言“
  • 原文地址:https://www.cnblogs.com/zhaopengpeng/p/16003687.html
Copyright © 2020-2023  润新知