• yolo v2使用总结


    以下都是基于yolo v2版本的,对于现在的v3版本,可以先clone下来,再git checkout回v2版本。

    玩了三四个月的yolo后发现数值相当不稳定,yolo只能用来小打小闹了。

    v2训练的权重用v3做预测,结果不一样。

    我的环境是 window 10 + cuda9.0 + opencv 3.4.0 + VS2015

    先在这个地方下源文件:https://github.com/AlexeyAB/darknet

    下好后,先打开用文本编辑器打开 darknet.vcxproj,将两处 cuda9.1 改成 cuda9.0

    还要拷贝opencv的两个dll到 x64 下

    用 VS2015 打开 darknet.sln,生成一下,在 darknet-masteruilddarknetx64 下面得到一个 darknet.exe

    这个时候已经可以用 训练好的模型 对 训练时的类别 做预测了
    当然,要预先下好对应的 weights 文件
    darknet.exe detector test cfg/combine9k.data yolo9000.cfg yolo9000.weights data/dog.jpg

    先准备好训练图片!

    一定要先做好文件重命名工作!不然后面想添加、修改、删减样本都很痛苦。

    然后再用 windows 版的 labelImg 做标注
    然后修改相关配置文件,然后就可以开始训练了。

    训练前下好一个 darknet19_448.conv.23 文件

    训练命令如下
    darknet.exe detector train cfg/voc.data cfg/yolo-voc.cfg cfg/darknet19_448.conv.23
    要准备好.data 和 .cfg 文件 以及 训练数据集

    训练时如果说找不到 txt 文件,直接把 txt 文件拷贝到image文件夹下

    最后会得到一个自己的 weights 文件可以用来预测自己的类别

     

    2018年8月24日08:58:19

    总结一个比较完整的流程出来,大致有以下几个步骤:

    1,不停的图片采集,以及不停的对新采集的图片重命名。

    因为采集到的图片名称可能是以秒命名的,也有可能是按日期命名的,或者有其他命名规范。

    后续标注和训练完测试的时候,如果发现某些样本图片不理想要剔除或跳过的时候,名字不易找就比较麻烦。

    2,图片标注的技巧,实际标注的xml数量是少于采集到的图片数量的,因为有些图片拍的角度或光照不理想。

    3,离线数据增强

    yolo本身有一些在线的数据增强,然而没有深入阅读过代码的话很难改的动,而且提供的增强手段有限。所以做了个

    简单的离线数据增强。

    4,训练数据预处理,修改cfg文件等。

    5,训练和测试

    6,在opencv中调用,C++和python两种版本。以及批量抠图。

    =====================================================

    1,重命名    rename.py

     1 import os
     2 
     3 def getFilenames(filepath):
     4     '''
     5     得到一个文件夹下所有的文件名,不包含后缀, 忽略文件夹
     6     '''
     7     filenames = []
     8     
     9     for file in os.listdir(filepath):
    10         pt = os.path.join(filepath, file)
    11         
    12         if( os.path.isfile(pt) ):
    13             filename = os.path.splitext(file)[0]
    14             filenames.append(filename)
    15     return filenames
    16 
    17 filepath = "origin_img"
    18 filenames = getFilenames(filepath)
    19 print(filenames)
    20 
    21 
    22 # 只对文件夹中新增加的图片重命名,增量式重命名。所以要取得图片文件名中最大的数字
    23 def get_max_num(filenames):
    24     max_num = 0
    25     
    26     for name in filenames:
    27         if( name.isdigit() and int(name) < 10000):
    28             if(int(name) > max_num):
    29                 max_num = int(name)
    30     
    31     return max_num
    32 
    33 print("max num:", get_max_num(filenames))
    34 
    35 renameCount = get_max_num(filenames)+1
    36 
    37 #renameCount = 1
    38 
    39 for file in os.listdir(filepath):
    40     if( os.path.isfile( os.path.join(filepath, file) ) ): # 如果是文件
    41         filename = os.path.splitext(file)[0] # 取得文件名
    42         
    43         if( not filename.isdigit() ): # 如果文件名不是数字,则重命名
    44             print("rename count 1:", renameCount)
    45             os.rename(os.path.join(filepath, file), os.path.join(filepath, str('%03d'%renameCount)+".jpeg"))
    46             renameCount+=1
    47             
    48         if( filename.isdigit() and int(filename) > 10000 ):
    49             print("rename count 2:", renameCount)
    50             os.rename(os.path.join(filepath, file), os.path.join(filepath, str('%03d'%renameCount)+".jpeg"))
    51             renameCount+=1
    View Code

    2,图片标注技巧

    用labelimg,尽量从特征的角度考虑,把目标物体最明显的特征框进去,跟程序实际工作情况差距太大的图片

    就不要标注了,如果特征不是太明显的也标了,那样本数量就要上去。

    标注的xml是少于图片数量的,需要进一步将标注过的图片拿出来。

     1 import os
     2 import shutil
     3 
     4 def getFilenames(filepath):
     5     '''得到一个文件夹下所有的文件名,不包含后缀
     6     '''
     7     filelist = os.listdir(filepath)
     8 
     9     filenames = []
    10     
    11     for files in filelist:
    12         filename = os.path.splitext(files)[0]
    13     #    print(files)
    14     #    print(filename)
    15         filenames.append(filename)
    16     return filenames
    17 
    18 
    19 xmlpath = 'xml'  # 这个是标注过后xml所在文件夹
    20 imgpath = 'img'  # 这个是标注时图片所在文件夹,图片数目多于xml
    21 
    22 img_write_path = 'img_less'  # 把标注过的图像拷贝到这个文件夹
    23 
    24 filenames = getFilenames(xmlpath)
    25 
    26 for i in range(len(filenames)):
    27     filename = filenames[i]
    28 #    print(filename)
    29     
    30     jpgpath = imgpath + "/" + str(filename) + ".jpeg"
    31 #    print(jpgpath)
    32     
    33     jpg_wrt_path = img_write_path + "/" + str(filename) + ".jpg"
    34     shutil.copy(jpgpath, jpg_wrt_path)
    View Code

    3,离线数据增强

    有些数据增强要修改xml,比如水平翻转,旋转,裁剪,有些不要,比如对颜色、光照做扰动等。

    本来写了水平翻转和随机裁剪,后来想想yolo把图片resize到416,随机裁剪没什么卵用。

    注意数据增强的时候不要搞出一堆人为的特征让yolo去学,要合理的增强。

    这里只放一个水平翻转的功能出来。

    util.py

      1 import numpy as np
      2 import cv2
      3 import matplotlib.pyplot as plt
      4 import random
      5 import os
      6 
      7 
      8 def showimg(img):
      9     channelNum = len(img.shape)
     10     
     11     if channelNum == 3:
     12         fig = plt.subplots(1),plt.imshow( cv2.cvtColor(img,  cv2.COLOR_BGR2RGB)  )
     13     if channelNum == 2:
     14         fig = plt.subplots(1),plt.imshow( img  )
     15         
     16 
     17 def scaleimg(img, scale = 1.0):
     18     H, W, C = img.shape
     19     size = (int(scale*W), int(scale*H))  
     20     img = cv2.resize(img, size, interpolation=cv2.INTER_AREA)
     21     del H, W, C, size, scale
     22     return img.copy()
     23 
     24 
     25 # img = rotateimg(image, angle)
     26 def rotateimg(image, angle, center=None, scale=1.0):
     27     # 获取图像尺寸
     28     (h, w) = image.shape[:2]
     29 
     30     # 若未指定旋转中心,则将图像中心设为旋转中心
     31     if center is None: 
     32         center = (w / 2, h / 2)
     33 
     34     # 执行旋转
     35     M = cv2.getRotationMatrix2D(center, angle, scale) # 给的角度为正的时候,则逆时针旋转
     36     rotated = cv2.warpAffine(image, M, (w, h), flags=cv2.INTER_CUBIC)
     37 #    rotated = cv2.warpAffine(image, M, (w, h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE)
     38 
     39     return rotated  # 返回旋转后的图像, angle是角度制,不是弧度制
     40 
     41 
     42 '''
     43 1, 读取xml返回结果
     44     输入:CLASS_NAMES元组   xml路径
     45     返回:(H, W, boxes)   boxes是一个二维np数组,6列分别为
     46           id classid xmin xmax ymin ymax
     47            0       1    2    3    4    5
     48           
     49 2,  将boxes  CLASS_NAMES  H W 信息写入xml
     50     输入:boxes  CLASS_NAMES  H W   xml路径
     51     输出:硬盘上的一个xml
     52     
     53 3, 根据 img,boxes数组,class_names,画出一个图来
     54     输入:img,boxes,class_names
     55 '''
     56 
     57 import xml.etree.ElementTree as ET
     58 import numpy as np
     59 
     60 #CLASS_NAMES = ('person', 'dog')  # 下标从0开始,这里可以没有顺序,最好有顺序
     61 
     62 #          id classid xmin xmax ymin ymax
     63 #           0       1    2    3    4    5
     64 def xml2boxes(xmlpath, CLASS_NAMES):
     65     print("xmlpath:", xmlpath)
     66     
     67     cls_to_idx = dict( zip( CLASS_NAMES            , range(len(CLASS_NAMES)) ))
     68     idx_to_cls = dict( zip( range(len(CLASS_NAMES)), CLASS_NAMES             ))
     69     
     70 #    print(cls_to_idx)
     71 #    print(idx_to_cls)
     72     
     73     annotations = ET.parse(xmlpath)
     74     # 获得 HWC
     75     size = annotations.find('size')
     76     W = int(size.find('width').text)
     77     H = int(size.find('height').text)
     78     C = int(size.find('depth').text)
     79     # 获得类别和具体坐标
     80     bbox = list()
     81     count = 1
     82     for obj in annotations.iter('object'): # 提取 xml文件中的信息
     83         line = []
     84         bndbox_anno = obj.find('bndbox')
     85         # xmin等从 1 开始计数
     86         tmp = map(int, [bndbox_anno.find('xmin').text, 
     87                         bndbox_anno.find('xmax').text,
     88                         bndbox_anno.find('ymin').text, 
     89                         bndbox_anno.find('ymax').text])
     90         tmp = list(tmp) # 1 x 4
     91         
     92         name = obj.find('name').text.lower().strip()
     93     
     94         line.append(count)
     95         line.append(cls_to_idx[name])
     96         line.append(tmp[0])
     97         line.append(tmp[1])
     98         line.append(tmp[2])
     99         line.append(tmp[3])
    100         count = count + 1
    101 #        print(line)
    102         bbox.append( line )
    103     
    104     boxes = np.stack(bbox).astype(np.int32)
    105     return boxes, H, W
    106 
    107 #boxes, H, W = xml2boxes("1.xml", CLASS_NAMES)
    108 #print("boxes:
    ", boxes)
    109 # 对只有一个类别的时候,CLASS_NAMES要在后面加一个字符串
    110 # 比如 CLASS_NAMES = ("apple", "xxxx") 这是个bug,还没修
    111 
    112 from lxml.etree import Element, SubElement, tostring
    113 from xml.dom.minidom import parseString
    114 
    115 
    116 ######################################################
    117 # boxes2xml_labelImg(boxes, CLASS_NAMES, H, W, xmlpath, wrtin_img_folder_name, imgName, img_fullpath)
    118 def boxes2xml_labelImg(boxes, CLASS_NAMES, H, W, xmlpath, wrtin_img_folder_name,
    119                        imgName, img_fullpath):
    120     '''
    121     这是一个labelImg可以查看的版本
    122     
    123     这个时候要求 CLASS_NAMES 是有顺序的,和boxes里头的第二列
    124     的类别id要一一对应
    125     '''
    126     cls_to_idx = dict( zip( CLASS_NAMES            , range(len(CLASS_NAMES)) ))
    127     idx_to_cls = dict( zip( range(len(CLASS_NAMES)), CLASS_NAMES             ))
    128     
    129     
    130     node_annotation = Element('annotation')
    131     #################################################
    132     node_folder   = SubElement(node_annotation, 'folder')
    133     node_filename = SubElement(node_annotation, 'filename')
    134     node_path     = SubElement(node_annotation, 'path')
    135     
    136     node_source   = SubElement(node_annotation, 'source')
    137     node_database   = SubElement(node_source, 'database')
    138     
    139     
    140     node_folder.text = wrtin_img_folder_name  # 这个是定死的,赋值一次就不会变了
    141     node_filename.text = imgName              # 图片的文件名,不包含后缀
    142     node_path.text = img_fullpath                  # 随着文件名变化
    143     
    144     node_database.text = "Unknown"
    145     
    146     
    147     node_size     = SubElement(node_annotation, 'size')
    148     #################################################
    149     # node_size
    150     node_width   = SubElement(node_size, 'width')
    151     node_height  = SubElement(node_size, 'height')
    152     node_depth   = SubElement(node_size, 'depth')
    153     
    154     node_width.text  = str(W)
    155     node_height.text = str(H)
    156     node_depth.text  = str(3) # 默认是彩色
    157     #################################################
    158     node_segmented   = SubElement(node_annotation, 'segmented')
    159     node_segmented.text = "0"
    160     #################################################
    161     
    162     # node_object  若干    要循环
    163     for i in range(boxes.shape[0]):
    164         node_object = SubElement(node_annotation, 'object')
    165         classid = boxes[i, 1]
    166     #    print(idx_to_cls[classid])
    167         node_name   = SubElement(node_object, 'name')
    168         node_name.text = idx_to_cls[classid]
    169         
    170         node_pose   = SubElement(node_object, 'pose')
    171         node_truncated   = SubElement(node_object, 'truncated')
    172         node_Difficult   = SubElement(node_object, 'Difficult')
    173         
    174         node_pose.text = "Unspecified"
    175         node_truncated.text = "1"
    176         node_Difficult.text = "0"
    177         
    178         
    179         node_bndbox = SubElement(node_object, 'bndbox')
    180     
    181         node_xmin = SubElement(node_bndbox, 'xmin')
    182         node_ymin = SubElement(node_bndbox, 'ymin')
    183         node_xmax = SubElement(node_bndbox, 'xmax')
    184         node_ymax = SubElement(node_bndbox, 'ymax')
    185         
    186         node_xmin.text = str(boxes[i, 2])
    187         node_xmax.text = str(boxes[i, 3])
    188         node_ymin.text = str(boxes[i, 4])
    189         node_ymax.text = str(boxes[i, 5])
    190     
    191     ###################
    192     xml = tostring(node_annotation, pretty_print=True)  #格式化显示,该换行的换行
    193     dom = parseString(xml)
    194     
    195     test_string = xml.decode('utf-8')
    196     #print('test:
    ', test_string)
    197     
    198     with open(xmlpath, "w") as text_file:
    199         text_file.write(test_string)
    200         
    201 ######################################################
    202 def drawboxes(imgpath, boxes, CLASS_NAMES):
    203     import matplotlib.pyplot as plt
    204     import matplotlib.patches as patches
    205     import cv2
    206     
    207     cls_to_idx = dict( zip( CLASS_NAMES            , range(len(CLASS_NAMES)) ))
    208     idx_to_cls = dict( zip( range(len(CLASS_NAMES)), CLASS_NAMES             ))
    209     
    210     if isinstance(imgpath, str):
    211         img = cv2.imread(imgpath)
    212         img = cv2.cvtColor(img,  cv2.COLOR_BGR2RGB)
    213     if isinstance(imgpath, np.ndarray):
    214         img = imgpath
    215     
    216     fig, ax = plt.subplots(1)
    217     
    218     for i in range(boxes.shape[0]):
    219         bndbox = list(boxes[i,:])
    220         x = bndbox[2]
    221         y = bndbox[4]
    222         w = bndbox[3] - bndbox[2]
    223         h = bndbox[5] - bndbox[4]
    224         rect = patches.Rectangle( (x,y),w,h, linewidth=1,edgecolor='yellow',facecolor='none')
    225         ax.add_patch(rect)
    226         name = idx_to_cls[boxes[i, 1]]
    227         ax.text(x-5, y-5, name, style='italic', color='yellow', fontsize=12)
    228         
    229     ax.imshow(img)
    230 
    231 #drawboxes("1.jpg", boxes, CLASS_NAMES)
    232 
    233 ##################################
    234 
    235 
    236 
    237 def getFilenames(filepath):
    238     '''得到一个文件夹下所有的文件名,不包含后缀
    239     '''
    240     filelist = os.listdir(filepath)
    241 
    242     filenames = []
    243     
    244     for files in filelist:
    245         filename = os.path.splitext(files)[0]
    246     #    print(files)
    247     #    print(filename)
    248         filenames.append(filename)
    249     return filenames
    250 
    251 
    252 def fliplr_boxes(boxes, W):
    253     ''' 对boxes做水平翻转'''
    254     boxes_copy = boxes.copy()
    255     
    256     xmin = boxes[:, 2].copy()
    257     xmax = boxes[:, 3].copy()
    258     boxes_copy[:, 3] = W - 1 - xmin # 注意这里不是 2,3 是 3,2 不然xmin会大于xmax
    259     boxes_copy[:, 2] = W - 1 - xmax
    260     return boxes_copy
    View Code

    main.py

     1 import os
     2 import cv2
     3 import numpy as np
     4 import random
     5 import matplotlib.pyplot as plt
     6 import matplotlib.patches as patches
     7 
     8 from util import *
     9 
    10 img_read_path = "img_less"
    11 xml_read_path = "xml"
    12 
    13 img_write_path = "fliped_img" # 图片和xml水平翻转后的写入文件夹
    14 xml_write_path = "fliped_xml"
    15 
    16 
    17 
    18 filenames = getFilenames(xml_read_path)
    19 
    20 CLASS_NAMES = ('person', 'aa')  # 这里有个bug懒得改,一个类别时也要写两个进去
    21 
    22 count = 201
    23 
    24 wrtin_img_folder_name = "fliped_img"
    25 
    26 
    27 for i in range(len(filenames)):
    28     name = filenames[i]
    29     
    30     imgname = img_read_path + "/" + str(name) + ".jpg"
    31     img = cv2.imread(imgname)
    32     
    33     xmlname = xml_read_path + "/" + str(name) + ".xml"
    34     boxes, H, W = xml2boxes(xmlname, CLASS_NAMES)
    35 #    print("xmlname:", xmlname)
    36     
    37     H,W,C = img.shape
    38     ##############################
    39     fliped_boxes = fliplr_boxes(boxes, W)
    40     fliped_img = cv2.flip(img, 1)
    41     ##############################
    42     FileName = str(count)
    43     
    44     jpgpath = img_write_path + "/" + FileName + ".jpg"
    45     cv2.imwrite(jpgpath, fliped_img)
    46     
    47     xmlpath = xml_write_path + "/" + FileName + ".xml"
    48     boxes2xml_labelImg(fliped_boxes, CLASS_NAMES, H, W, xmlpath, wrtin_img_folder_name, FileName, jpgpath)
    49    
    50     count = count + 1
    View Code

    4,训练数据预处理,修改cfg文件

    下面这两个文件一定要执行,我从别的地方copy过来的。

    trans1.py

     1 import os
     2 import shutil
     3 
     4 savepath = os.getcwd()
     5 
     6 img_path = savepath + "/img_less"   # 存放训练图片的文件夹名
     7 xml_path = savepath + "/xml"   # 总共标注了 X 张图片
     8 
     9 val_num = 10   #验证集数量,可修改
    10 
    11 # 下面新建了 4 个目录
    12 validateImage_path = savepath + "/validateImage";
    13 trainImage_path    = savepath + "/trainImage";
    14 
    15 if os.path.exists(validateImage_path)== False:
    16     os.mkdir(validateImage_path)
    17 if os.path.exists(trainImage_path) == False:
    18     os.mkdir(trainImage_path)
    19     
    20 
    21 validateImageXML_path = savepath + "/validateImageXML"
    22 trainImageXML_path    = savepath + "/trainImageXML"
    23 
    24 if os.path.exists(validateImageXML_path)== False:
    25     os.mkdir(validateImageXML_path)
    26 if os.path.exists(trainImageXML_path) == False:
    27     os.mkdir(trainImageXML_path)
    28 #=================================================
    29 
    30 filelist = os.listdir(xml_path)  # 以xml文件夹中的数量为标准
    31 
    32 count = 0
    33 for files in filelist:
    34     filename = os.path.splitext(files)[0]  # 文件名
    35     
    36     origin_jpg_name = os.path.join(img_path, filename + '.jpg')
    37     
    38     validateImage_jpg_name = os.path.join(validateImage_path, filename + '.jpg')
    39     trainImage_jpg_name    = os.path.join(trainImage_path,    filename + '.jpg')
    40 #    print(validateImage_jpg_name)
    41     
    42     
    43     if count < val_num:
    44         shutil.copy(origin_jpg_name, validateImage_jpg_name); # 拷贝 validate 图片
    45         
    46         xml_olddir = os.path.join(xml_path,              filename + ".xml")
    47         xml_newdir = os.path.join(validateImageXML_path, filename + ".xml")
    48         
    49         shutil.copyfile(xml_olddir, xml_newdir) # 拷贝 validate xml文件
    50     else:
    51         shutil.copy(origin_jpg_name, trainImage_jpg_name)
    52         
    53         xml_olddir = os.path.join(xml_path,           filename + ".xml")
    54         xml_newdir = os.path.join(trainImageXML_path, filename + ".xml")
    55         
    56         shutil.copyfile(xml_olddir, xml_newdir)
    57         
    58     count=count+1;
    59 
    60 validate_txtpath = savepath + "/validateImageId.txt"
    61 train_txtpath    = savepath + "/trainImageId.txt"
    62 
    63 
    64 def listname(path, idtxtpath):
    65     filelist = os.listdir(path)  # 该文件夹下所有的文件(包括文件夹)
    66     f = open(idtxtpath, 'w')
    67     
    68     for files in filelist:  # 遍历所有文件
    69         Olddir = os.path.join(path, files)  # 原来的文件路径
    70         if os.path.isdir(Olddir):  # 如果是文件夹则跳过
    71             continue
    72         filename = os.path.splitext(files)[0]  # 文件名
    73 
    74         f.write(filename)
    75         f.write('
    ')
    76     f.close()
    77     
    78 
    79 listname(validateImage_path, validate_txtpath)
    80 listname(trainImage_path,    train_txtpath)
    View Code

    trans2.py

     1 import xml.etree.ElementTree as ET
     2 import pickle
     3 import string
     4 import os
     5 import shutil
     6 from os import listdir, getcwd
     7 from os.path import join
     8 
     9 sets=[('2012', 'train')]
    10 
    11 classes = ["person"]  
    12 
    13 
    14 def convert(size, box):
    15     dw = 1./size[0]
    16     dh = 1./size[1]
    17     x = (box[0] + box[1])/2.0
    18     y = (box[2] + box[3])/2.0
    19     w = box[1] - box[0]
    20     h = box[3] - box[2]
    21     x = x*dw
    22     w = w*dw
    23     y = y*dh
    24     h = h*dh
    25     return (x,y,w,h)
    26 
    27 def convert_annotation(image_id,flag,savepath):
    28     if flag == 0:
    29         in_file = open(savepath+'/trainImageXML/%s.xml' % (image_id))
    30         labeltxt = savepath+'/trainImageLabelTxt';
    31         if os.path.exists(labeltxt) == False:
    32             os.mkdir(labeltxt);
    33         out_file = open(savepath+'/trainImageLabelTxt/%s.txt' % (image_id), 'w')
    34         tree = ET.parse(in_file)
    35         root = tree.getroot()
    36         size = root.find('size')
    37         w = int(size.find('width').text)
    38         h = int(size.find('height').text)
    39     elif flag == 1:
    40         in_file = open(savepath+'/validateImageXML/%s.xml' % (image_id))
    41         labeltxt = savepath + '/validateImageLabelTxt';
    42         if os.path.exists(labeltxt) == False:
    43             os.mkdir(labeltxt);
    44         out_file = open(savepath+'/validateImageLabelTxt/%s.txt' % (image_id), 'w')
    45         tree = ET.parse(in_file)
    46         root = tree.getroot()
    47         size = root.find('size')
    48         w = int(size.find('width').text)
    49         h = int(size.find('height').text)
    50 
    51 
    52 
    53     for obj in root.iter('object'):
    54 #        difficult = obj.find('difficult').text
    55         cls = obj.find('name').text
    56 #        if cls not in classes or int(difficult) == 1:
    57 #            continue
    58         cls_id = classes.index(cls)
    59         xmlbox = obj.find('bndbox')
    60         b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
    61         bb = convert((w,h), b)
    62         out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '
    ')
    63 
    64 wd = getcwd()
    65 
    66 for year, image_set in sets:
    67     savepath = os.getcwd();
    68     idtxt = savepath + "/validateImageId.txt";
    69     pathtxt = savepath + "/validateImagePath.txt";
    70     
    71     image_ids = open(idtxt).read().strip().split()
    72     list_file = open(pathtxt, 'w')
    73     
    74     s = 'xefxbbxbf'
    75     for image_id in image_ids:
    76         nPos = image_id.find(s)
    77         if nPos >= 0:
    78             image_id = image_id[3:]
    79         list_file.write('%s/validateImage/%s.jpg
    ' % (wd, image_id))
    80         print(image_id)
    81         convert_annotation(image_id, 1, savepath)
    82     list_file.close()
    83 
    84     idtxt = savepath + "/trainImageId.txt";
    85     pathtxt = savepath + "/trainImagePath.txt" ;
    86     image_ids = open(idtxt).read().strip().split()
    87     list_file = open(pathtxt, 'w')
    88     s = 'xefxbbxbf'
    89     for image_id in image_ids:
    90         nPos = image_id.find(s)
    91         if nPos >= 0:
    92            image_id = image_id[3:]
    93         list_file.write('%s/trainImage/%s.jpg
    '%(wd,image_id))
    94         print(image_id)
    95         convert_annotation(image_id,0,savepath)
    96     list_file.close()
    View Code

    训练的时候,编译好了,在windows下生成了darknet.exe,ubuntu下生成了可执行文件darknet。

    然后在darknet.exe同级目录新建一个文件夹叫(比如训练做行人识别)train_person,在该文件夹下新建一个

    backup文件夹,新建一个data文件夹,把数据增强过的图片文件夹img_less和xml拷贝到data下面,将trans1.py

    和trans2.py也放到data文件夹下,然后先执行trans1.py,再执行trans2.py。然后把生成的txt拷贝到训练图片目录下。

    至于修改cfg文件,看其他的博客吧。我就提一下修改那个anchors,faster-rcnn中有9个比例确定的anchor。yolo中则是

    统计了样本中的标注框然后做聚类。用了效果不错我才修改的。

      1 # coding=utf-8
      2 # k-means ++ for YOLOv2 anchors
      3 # 通过k-means ++ 算法获取YOLOv2需要的anchors的尺寸
      4 import numpy as np
      5 
      6 # 定义Box类,描述bounding box的坐标
      7 class Box():
      8     def __init__(self, x, y, w, h):
      9         self.x = x
     10         self.y = y
     11         self.w = w
     12         self.h = h
     13 
     14 
     15 # 计算两个box在某个轴上的重叠部分
     16 # x1是box1的中心在该轴上的坐标
     17 # len1是box1在该轴上的长度
     18 # x2是box2的中心在该轴上的坐标
     19 # len2是box2在该轴上的长度
     20 # 返回值是该轴上重叠的长度
     21 def overlap(x1, len1, x2, len2):
     22     len1_half = len1 / 2
     23     len2_half = len2 / 2
     24 
     25     left = max(x1 - len1_half, x2 - len2_half)
     26     right = min(x1 + len1_half, x2 + len2_half)
     27 
     28     return right - left
     29 
     30 
     31 # 计算box a 和box b 的交集面积
     32 # a和b都是Box类型实例
     33 # 返回值area是box a 和box b 的交集面积
     34 def box_intersection(a, b):
     35     w = overlap(a.x, a.w, b.x, b.w)
     36     h = overlap(a.y, a.h, b.y, b.h)
     37     if w < 0 or h < 0:
     38         return 0
     39 
     40     area = w * h
     41     return area
     42 
     43 
     44 # 计算 box a 和 box b 的并集面积
     45 # a和b都是Box类型实例
     46 # 返回值u是box a 和box b 的并集面积
     47 def box_union(a, b):
     48     i = box_intersection(a, b)
     49     #print a.w,a.h,b.w,b.h
     50     u = a.w * a.h + b.w * b.h - i
     51     return u
     52 
     53 
     54 # 计算 box a 和 box b 的 iou
     55 # a和b都是Box类型实例
     56 # 返回值是box a 和box b 的iou
     57 def box_iou(a, b):
     58     #print box_union(a, b)
     59     return box_intersection(a, b) / box_union(a, b)
     60 
     61 
     62 # 使用k-means ++ 初始化 centroids,减少随机初始化的centroids对最终结果的影响
     63 # boxes是所有bounding boxes的Box对象列表
     64 # n_anchors是k-means的k值
     65 # 返回值centroids 是初始化的n_anchors个centroid
     66 def init_centroids(boxes,n_anchors):
     67     centroids = []
     68     boxes_num = len(boxes)
     69 
     70     centroid_index = np.random.choice(boxes_num, 1)
     71     centroids.append(boxes[centroid_index])
     72 
     73     print(centroids[0].w,centroids[0].h)
     74 
     75     for centroid_index in range(0,n_anchors-1):
     76 
     77         sum_distance = 0
     78         distance_thresh = 0
     79         distance_list = []
     80         cur_sum = 0
     81 
     82         for box in boxes:
     83             min_distance = 1
     84             for centroid_i, centroid in enumerate(centroids):
     85                 distance = (1 - box_iou(box, centroid))
     86                 if distance < min_distance:
     87                     min_distance = distance
     88             sum_distance += min_distance
     89             distance_list.append(min_distance)
     90 
     91         distance_thresh = sum_distance*np.random.random()
     92 
     93         for i in range(0,boxes_num):
     94             cur_sum += distance_list[i]
     95             if cur_sum > distance_thresh:
     96                 centroids.append(boxes[i])
     97                 print(boxes[i].w, boxes[i].h)
     98                 break
     99 
    100     return centroids
    101 
    102 
    103 # 进行 k-means 计算新的centroids
    104 # boxes是所有bounding boxes的Box对象列表
    105 # n_anchors是k-means的k值
    106 # centroids是所有簇的中心
    107 # 返回值new_centroids 是计算出的新簇中心
    108 # 返回值groups是n_anchors个簇包含的boxes的列表
    109 # 返回值loss是所有box距离所属的最近的centroid的距离的和
    110 def do_kmeans(n_anchors, boxes, centroids):
    111     loss = 0
    112     groups = []
    113     new_centroids = []
    114     for i in range(n_anchors):
    115         groups.append([])
    116         new_centroids.append(Box(0, 0, 0, 0))
    117 
    118     for box in boxes:
    119         min_distance = 1
    120         group_index = 0
    121         for centroid_index, centroid in enumerate(centroids):
    122             distance = (1 - box_iou(box, centroid))
    123             if distance < min_distance:
    124                 min_distance = distance
    125                 group_index = centroid_index
    126         groups[group_index].append(box)
    127         loss += min_distance
    128         new_centroids[group_index].w += box.w
    129         new_centroids[group_index].h += box.h
    130 
    131     for i in range(n_anchors):
    132         new_centroids[i].w /= len(groups[i])
    133         new_centroids[i].h /= len(groups[i])
    134 
    135     return new_centroids, groups, loss
    136 
    137 
    138 # 计算给定bounding boxes的n_anchors数量的centroids
    139 # label_path是训练集列表文件地址
    140 # n_anchors 是anchors的数量
    141 # loss_convergence是允许的loss的最小变化值
    142 # grid_size * grid_size 是栅格数量
    143 # iterations_num是最大迭代次数
    144 # plus = 1时启用k means ++ 初始化centroids
    145 def compute_centroids(label_path,n_anchors,loss_convergence,grid_size,iterations_num,plus):
    146 
    147     boxes = []
    148     label_files = []
    149     f = open(label_path)
    150     for line in f:
    151         label_path = line.rstrip().replace('images', 'labels')
    152         label_path = label_path.replace('JPEGImages', 'labels')
    153         label_path = label_path.replace('.jpg', '.txt')
    154         label_path = label_path.replace('.JPEG', '.txt')
    155         label_files.append(label_path)
    156     f.close()
    157 
    158     for label_file in label_files:
    159         f = open(label_file)
    160         for line in f:
    161             temp = line.strip().split(" ")
    162             if len(temp) > 1:
    163                 boxes.append(Box(0,0, float(temp[3]), float(temp[4])))
    164                 print(temp[3],temp[4])
    165                 if float(temp[3])<0:
    166                     print(label_file)
    167     print('done')
    168     if plus:
    169         centroids = init_centroids(boxes, n_anchors)
    170     else:
    171         centroid_indices = np.random.choice(len(boxes), n_anchors)
    172         centroids = []
    173         for centroid_index in centroid_indices:
    174             centroids.append(boxes[centroid_index])
    175 
    176     # iterate k-means
    177     centroids, groups, old_loss = do_kmeans(n_anchors, boxes, centroids)
    178     iterations = 1
    179     while (True):
    180         centroids, groups, loss = do_kmeans(n_anchors, boxes, centroids)
    181         iterations = iterations + 1
    182         print("loss = %f" % loss)
    183         if abs(old_loss - loss) < loss_convergence or iterations > iterations_num:
    184             break
    185         old_loss = loss
    186 
    187         for centroid in centroids:
    188             print(centroid.w * grid_size, centroid.h * grid_size)
    189 
    190     # print result
    191     for centroid in centroids:
    192         #print("k-means result:	 ", centroid.w * grid_size, ",", centroid.h * grid_size)
    193         #str('%.03f'%maxVal)
    194 
    195         print("k-means result:	 ", str('%.06f'%(centroid.w * grid_size)), ",", str('%.06f'%(centroid.h * grid_size)))
    196 
    197 
    198 label_path = "xx/train_person/data/trainImagePath.txt"
    199 n_anchors = 5
    200 loss_convergence = 1e-3
    201 grid_size = 13
    202 iterations_num = 10000000
    203 plus = 0
    204 compute_centroids(label_path,n_anchors,loss_convergence,grid_size,iterations_num,plus)
    View Code

    5,训练和测试

    直接命令行了。训练的时候内存够的话,再开一个命令行窗口做测试,可以一边训练,一边偶尔看看预测效果。

    6,在opencv中调用

    C++版本

      1 #include <opencv2/opencv.hpp>
      2 #include <opencv2/dnn.hpp>
      3 
      4 #include <fstream>
      5 #include <iostream>
      6 #include <algorithm>
      7 #include <cstdlib>
      8 
      9 using namespace std;
     10 using namespace cv;
     11 using namespace cv::dnn;
     12 
     13 //char* jpgpath = "";
     14 //char* cfgpath = "";
     15 //char* weightspath = "";
     16 //char* namespath = "";
     17 //Mat boxes = yoloMultiPredict(jpgpath, cfgpath, weightspath, namespath);
     18 //cout << "boxes:
    " << boxes << endl;  // class prob xmin xmax ymin ymax
     19 Mat yoloMultiPredict(char* jpgpath, char* cfgpath, char* weightspath, char* namespath)
     20 {
     21     Mat boxes = Mat::zeros(0, 6, CV_16UC1);
     22     Mat frame = imread(jpgpath);
     23 
     24     dnn::Net net = readNetFromDarknet(cfgpath, weightspath);
     25     if (net.empty())
     26     {
     27         printf("Could not load net...
    ");
     28     }
     29 
     30 
     31     // 得到类别名称
     32     if (0)
     33     {
     34         ifstream classNamesFile(namespath);
     35         vector<string> classNamesVec;
     36 
     37         if (classNamesFile.is_open())
     38         {
     39             string className = "";
     40             while (std::getline(classNamesFile, className))
     41                 classNamesVec.push_back(className);
     42         }
     43         for (int i = 0; i < classNamesVec.size(); i++)
     44             cout << i << "	" << classNamesVec[i] << endl;
     45         cout << endl;
     46     }
     47 
     48     Mat inputBlob = blobFromImage(frame, 1 / 255.F, Size(416, 416), Scalar(), true, false);
     49     net.setInput(inputBlob, "data");
     50 
     51     // 检测
     52     Mat detectionMat = net.forward("detection_out");
     53     //cout << "forward" << endl;
     54 
     55     // 输出结果
     56     for (int i = 0; i < detectionMat.rows; i++)
     57     {
     58         const int probability_index = 5;
     59         const int probability_size = detectionMat.cols - probability_index;
     60         float *prob_array_ptr = &detectionMat.at<float>(i, probability_index);
     61         size_t objectClass = max_element(prob_array_ptr, prob_array_ptr + probability_size) - prob_array_ptr;
     62 
     63         float confidence = detectionMat.at<float>(i, (int)objectClass + probability_index);
     64 
     65         if (confidence > 0.24)
     66         {
     67             float x = detectionMat.at<float>(i, 0);
     68             float y = detectionMat.at<float>(i, 1);
     69             float width = detectionMat.at<float>(i, 2);
     70             float height = detectionMat.at<float>(i, 3);
     71 
     72             int xmin = static_cast<int>((x - width / 2) * frame.cols);
     73             int xmax = static_cast<int>((x + width / 2) * frame.cols);
     74 
     75             int ymin = static_cast<int>((y - height / 2) * frame.rows);
     76             int ymax = static_cast<int>((y + height / 2) * frame.rows);
     77 
     78             // clip
     79             if (xmin<0)
     80                 xmin = 0;
     81             if (xmax > frame.cols)
     82                 xmax = frame.cols - 1;
     83             if (ymin<0)
     84                 ymin = 0;
     85             if (ymax > frame.rows)
     86                 ymax = frame.rows - 1;
     87 
     88             //rectangle(frame, cvPoint(xmin, ymin), cvPoint(xmax, ymax), Scalar(0, 0, 255), 4, 1, 0);
     89             //cout << "x y w h	" << x << "	" << y << "	" << width << "	" << height << endl;
     90 
     91             // class prob xmin xmax ymin ymax
     92             Mat L = (Mat_<short>(1, 6) << (short)objectClass, (short)(confidence * 100), xmin, xmax, ymin, ymax);
     93             //cout << "L:" << L << endl;
     94             boxes.push_back(L);
     95         }
     96     }
     97 
     98     return boxes;
     99 
    100 }
    View Code

    2018年4月18日以后的opencv已经可以导入yolo v3的训练文件了。

    直接用命令行做的预测结果和opencv中导入配置文件的预测结果不一样,有的时候相差还很大,不堪重用啊。

    还是要在什么框架下训练就在什么框架下用,不然数值稳定性不能保证。。。

    python版本

    opencv的samples/dnn下的例子改造的。

     1 import cv2
     2 import numpy as np
     3 import os
     4 
     5 cwd = os.path.split(os.path.realpath(__file__))[0]
     6 
     7 def darknetPredict(jpgpathOrMat, cfgpath, wtspath):
     8     net = cv2.dnn.readNetFromDarknet(cfgpath, wtspath)
     9     
    10     confThreshold = 0.24
    11     nmsThreshold = 0.4
    12     
    13     def getOutputsNames(net):
    14         layersNames = net.getLayerNames()
    15         return [layersNames[i[0] - 1] for i in net.getUnconnectedOutLayers()]
    16     
    17     if(isinstance(jpgpathOrMat, str)):
    18         frame = cv2.imread(jpgpathOrMat)
    19     if(isinstance(jpgpathOrMat, np.ndarray)):
    20         frame = jpgpathOrMat
    21     
    22     frameHeight = frame.shape[0]
    23     frameWidth = frame.shape[1]
    24     
    25     # Create a 4D blob from a frame.
    26     inpW = 416
    27     inpH = 416
    28     blob = cv2.dnn.blobFromImage(frame, 1.0 / 255, (inpW, inpH), (0, 0, 0), swapRB=True, crop=False)
    29     
    30     net.setInput(blob)  # Run a model
    31     
    32     outs = net.forward(getOutputsNames(net))
    33     #
    34     classIds = []
    35     confidences = []
    36     boxes = []
    37     for out in outs:
    38     #    print('out:', out)
    39         for detection in out:
    40             scores = detection[5:]
    41             classId = np.argmax(scores)
    42             confidence = scores[classId]
    43             if confidence > confThreshold:
    44                 center_x = int(detection[0] * frameWidth)
    45                 center_y = int(detection[1] * frameHeight)
    46                 width    = int(detection[2] * frameWidth)
    47                 height   = int(detection[3] * frameHeight)
    48                 left     = int(center_x - width / 2)
    49                 top      = int(center_y - height / 2)
    50                 classIds.append(classId)
    51                 confidences.append(float(confidence))
    52                 boxes.append([left, top, width, height])
    53     
    54     
    55     rst_boxes = []  
    56     indices = cv2.dnn.NMSBoxes(boxes, confidences, confThreshold, nmsThreshold)
    57     for i in indices:
    58         i = i[0]
    59         box = boxes[i]
    60         # left top  w  h
    61         #    0   1  2  3
    62         # xmin ymin w  h
    63         
    64         left = box[0]
    65         top = box[1]
    66         width = box[2]
    67         height = box[3]
    68 #        print("confidences:", confidences[i])
    69         
    70         xmin, ymin, xmax, ymax = [left, top, left+width, top+height]
    71         
    72         xmin = np.clip(xmin, 0, frameWidth-1 )
    73         xmax = np.clip(xmax, 0, frameWidth-1 )
    74         ymin = np.clip(ymin, 0, frameHeight-1)
    75         ymax = np.clip(ymax, 0, frameHeight-1)
    76         
    77         line = [classIds[i], confidences[i], xmin, ymin, xmax, ymax]
    78         #       classid      prob            xmin  ymin xmax   ymax
    79         #             0         1               2     3    4      5
    80         rst_boxes.append(line)
    81         
    82     rst_boxes = np.asarray(rst_boxes)
    83     return rst_boxes
    View Code

    差不多就是这样了。

  • 相关阅读:
    mssql锁
    gridview 分页兼容BOOTSTRAP
    BOOTSTRAP前端模板
    bootstrap 简单模板
    ajax 跨域访问的解决方案
    webapi之权限验证
    webapi权限常见错误
    ajax跨域解决方案
    iis 部署webapi常见错误及解决方案
    OOM AutoMapper的简单实用
  • 原文地址:https://www.cnblogs.com/shepherd2015/p/8671646.html
Copyright © 2020-2023  润新知