以下都是基于yolo v2版本的,对于现在的v3版本,可以先clone下来,再git checkout回v2版本。
玩了三四个月的yolo后发现数值相当不稳定,yolo只能用来小打小闹了。
v2训练的权重用v3做预测,结果不一样。
我的环境是 window 10 + cuda9.0 + opencv 3.4.0 + VS2015
先在这个地方下源文件:https://github.com/AlexeyAB/darknet
下好后,先打开用文本编辑器打开 darknet.vcxproj,将两处 cuda9.1 改成 cuda9.0
还要拷贝opencv的两个dll到 x64 下
用 VS2015 打开 darknet.sln,生成一下,在 darknet-master\build\darknet\x64 下面得到一个 darknet.exe
这个时候已经可以用 训练好的模型 对 训练时的类别 做预测了
当然,要预先下好对应的 weights 文件
darknet.exe detector test cfg/combine9k.data yolo9000.cfg yolo9000.weights data/dog.jpg
先准备好训练图片!
一定要先做好文件重命名工作!不然后面想添加、修改、删减样本都很痛苦。
然后再用 windows 版的 labelImg 做标注
然后修改相关配置文件,然后就可以开始训练了。
训练前下好一个 darknet19_448.conv.23 文件
训练命令如下
darknet.exe detector train cfg/voc.data cfg/yolo-voc.cfg cfg/darknet19_448.conv.23
要准备好.data 和 .cfg 文件 以及 训练数据集
训练时如果说找不到 txt 文件,直接把 txt 文件拷贝到image文件夹下
最后会得到一个自己的 weights 文件可以用来预测自己的类别
2018年8月24日08:58:19
总结一个比较完整的流程出来,大致有以下几个步骤:
1,不停的图片采集,以及不停的对新采集的图片重命名。
因为采集到的图片名称可能是以秒命名的,也有可能是按日期命名的,或者有其他命名规范。
后续标注和训练完测试的时候,如果发现某些样本图片不理想要剔除或跳过的时候,名字不易找就比较麻烦。
2,图片标注的技巧,实际标注的xml数量是少于采集到的图片数量的,因为有些图片拍的角度或光照不理想。
3,离线数据增强
yolo本身有一些在线的数据增强,然而没有深入阅读过代码的话很难改的动,而且提供的增强手段有限。所以做了个
简单的离线数据增强。
4,训练数据预处理,修改cfg文件等。
5,训练和测试
6,在opencv中调用,C++和python两种版本。以及批量抠图。
=====================================================
1,重命名 rename.py
import os


def getFilenames(filepath):
    '''
    Collect the stem (file name without extension) of every regular file
    in *filepath*; sub-directories are ignored.
    '''
    stems = []
    for entry in os.listdir(filepath):
        full = os.path.join(filepath, entry)
        if os.path.isfile(full):
            stems.append(os.path.splitext(entry)[0])
    return stems


filepath = "origin_img"
filenames = getFilenames(filepath)
print(filenames)


# Only newly added images get renamed (incremental renaming), so we need
# the largest number already used among the existing numeric file names.
def get_max_num(filenames):
    numeric = [int(name) for name in filenames
               if name.isdigit() and int(name) < 10000]
    return max(numeric, default=0)


print("max num:", get_max_num(filenames))

renameCount = get_max_num(filenames) + 1

for entry in os.listdir(filepath):
    src = os.path.join(filepath, entry)
    if not os.path.isfile(src):
        continue
    stem = os.path.splitext(entry)[0]

    if not stem.isdigit():
        # Non-numeric name: assign the next free sequential number.
        print("rename count 1:", renameCount)
        os.rename(src, os.path.join(filepath, '%03d' % renameCount + ".jpeg"))
        renameCount += 1
    elif int(stem) > 10000:
        # Numeric but out of the managed range (e.g. second/date based
        # camera names): pull it into the sequence as well.
        print("rename count 2:", renameCount)
        os.rename(src, os.path.join(filepath, '%03d' % renameCount + ".jpeg"))
        renameCount += 1
2,图片标注技巧
用labelimg,尽量从特征的角度考虑,把目标物体最明显的特征框进去,跟程序实际工作情况差距太大的图片
就不要标注了,如果特征不是太明显的也标了,那样本数量就要上去。
标注的xml是少于图片数量的,需要进一步将标注过的图片拿出来。
import os
import shutil


def getFilenames(filepath):
    '''Return every file name in *filepath* with the extension stripped.
    '''
    return [os.path.splitext(entry)[0] for entry in os.listdir(filepath)]


xmlpath = 'xml'              # folder that holds the annotation xml files
imgpath = 'img'              # folder with all images (more images than xmls)

img_write_path = 'img_less'  # annotated images are copied into this folder

filenames = getFilenames(xmlpath)

# Only images that actually have an annotation xml are copied over.
for filename in filenames:
    src = imgpath + "/" + str(filename) + ".jpeg"
    dst = img_write_path + "/" + str(filename) + ".jpg"
    shutil.copy(src, dst)
3,离线数据增强
有些数据增强要修改xml,比如水平翻转,旋转,裁剪,有些不要,比如对颜色、光照做扰动等。
本来写了水平翻转和随机裁剪,后来想想yolo把图片resize到416,随机裁剪没什么卵用。
注意数据增强的时候不要搞出一堆人为的特征让yolo去学,要合理的增强。
这里只放一个水平翻转的功能出来。
util.py
import numpy as np
import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import random
import os

import xml.etree.ElementTree as ET
from lxml.etree import Element, SubElement, tostring
from xml.dom.minidom import parseString


def showimg(img):
    """Show *img* with matplotlib; 3-channel BGR input is converted to RGB."""
    channelNum = len(img.shape)
    if channelNum == 3:
        fig = plt.subplots(1), plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    if channelNum == 2:
        fig = plt.subplots(1), plt.imshow(img)


def scaleimg(img, scale=1.0):
    """Return a copy of *img* uniformly resized by *scale*."""
    H, W, C = img.shape
    size = (int(scale * W), int(scale * H))
    return cv2.resize(img, size, interpolation=cv2.INTER_AREA).copy()


def rotateimg(image, angle, center=None, scale=1.0):
    """Rotate *image* by *angle* degrees (positive = counter-clockwise).

    If *center* is None the image center is used as the rotation center.
    *angle* is in degrees, not radians.
    """
    (h, w) = image.shape[:2]
    if center is None:
        center = (w / 2, h / 2)
    M = cv2.getRotationMatrix2D(center, angle, scale)
    rotated = cv2.warpAffine(image, M, (w, h), flags=cv2.INTER_CUBIC)
    return rotated


# Box-array layout used throughout this module:
#   id classid xmin xmax ymin ymax
#   0  1       2    3    4    5
def xml2boxes(xmlpath, CLASS_NAMES):
    """Parse a PASCAL-VOC style annotation xml.

    Returns (boxes, H, W) where boxes is an int32 array with the column
    layout documented above. CLASS_NAMES maps class name -> class id by
    position, so it should keep a stable order across files.
    """
    print("xmlpath:", xmlpath)

    cls_to_idx = dict(zip(CLASS_NAMES, range(len(CLASS_NAMES))))

    annotations = ET.parse(xmlpath)
    size = annotations.find('size')
    W = int(size.find('width').text)
    H = int(size.find('height').text)
    C = int(size.find('depth').text)

    bbox = []
    count = 1
    for obj in annotations.iter('object'):
        bndbox_anno = obj.find('bndbox')
        # VOC coordinates are 1-based
        xmin, xmax, ymin, ymax = (int(bndbox_anno.find(tag).text)
                                  for tag in ('xmin', 'xmax', 'ymin', 'ymax'))
        name = obj.find('name').text.lower().strip()
        bbox.append([count, cls_to_idx[name], xmin, xmax, ymin, ymax])
        count += 1

    # FIX: an annotation without any <object> used to crash np.stack on an
    # empty list; return an empty (0, 6) array instead.
    if not bbox:
        return np.zeros((0, 6), dtype=np.int32), H, W

    boxes = np.stack(bbox).astype(np.int32)
    return boxes, H, W


def boxes2xml_labelImg(boxes, CLASS_NAMES, H, W, xmlpath, wrtin_img_folder_name,
                       imgName, img_fullpath):
    """Write *boxes* back to an annotation xml that labelImg can open.

    CLASS_NAMES must be ordered so that its indices match the classid
    column (column 1) of *boxes*.
    """
    idx_to_cls = dict(zip(range(len(CLASS_NAMES)), CLASS_NAMES))

    node_annotation = Element('annotation')

    node_folder = SubElement(node_annotation, 'folder')
    node_filename = SubElement(node_annotation, 'filename')
    node_path = SubElement(node_annotation, 'path')

    node_source = SubElement(node_annotation, 'source')
    node_database = SubElement(node_source, 'database')

    node_folder.text = wrtin_img_folder_name  # fixed once per run
    node_filename.text = imgName              # image name without extension
    node_path.text = img_fullpath             # varies per image
    node_database.text = "Unknown"

    node_size = SubElement(node_annotation, 'size')
    node_width = SubElement(node_size, 'width')
    node_height = SubElement(node_size, 'height')
    node_depth = SubElement(node_size, 'depth')
    node_width.text = str(W)
    node_height.text = str(H)
    node_depth.text = str(3)  # assumes color images

    node_segmented = SubElement(node_annotation, 'segmented')
    node_segmented.text = "0"

    # one <object> node per box row
    for i in range(boxes.shape[0]):
        node_object = SubElement(node_annotation, 'object')
        classid = boxes[i, 1]
        node_name = SubElement(node_object, 'name')
        node_name.text = idx_to_cls[classid]

        node_pose = SubElement(node_object, 'pose')
        node_truncated = SubElement(node_object, 'truncated')
        # NOTE(review): VOC/labelImg normally use the lowercase tag
        # 'difficult'; kept capitalized here to preserve the original
        # output byte-for-byte -- confirm before changing.
        node_Difficult = SubElement(node_object, 'Difficult')
        node_pose.text = "Unspecified"
        node_truncated.text = "1"
        node_Difficult.text = "0"

        node_bndbox = SubElement(node_object, 'bndbox')
        node_xmin = SubElement(node_bndbox, 'xmin')
        node_ymin = SubElement(node_bndbox, 'ymin')
        node_xmax = SubElement(node_bndbox, 'xmax')
        node_ymax = SubElement(node_bndbox, 'ymax')
        node_xmin.text = str(boxes[i, 2])
        node_xmax.text = str(boxes[i, 3])
        node_ymin.text = str(boxes[i, 4])
        node_ymax.text = str(boxes[i, 5])

    xml = tostring(node_annotation, pretty_print=True)  # newline formatted
    dom = parseString(xml)

    with open(xmlpath, "w") as text_file:
        text_file.write(xml.decode('utf-8'))


def drawboxes(imgpath, boxes, CLASS_NAMES):
    """Plot *boxes* with their class labels on top of an image.

    *imgpath* may be a file path or an already-loaded BGR ndarray.
    """
    idx_to_cls = dict(zip(range(len(CLASS_NAMES)), CLASS_NAMES))

    if isinstance(imgpath, str):
        img = cv2.imread(imgpath)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    if isinstance(imgpath, np.ndarray):
        img = imgpath

    fig, ax = plt.subplots(1)

    for i in range(boxes.shape[0]):
        bndbox = list(boxes[i, :])
        x = bndbox[2]
        y = bndbox[4]
        w = bndbox[3] - bndbox[2]
        h = bndbox[5] - bndbox[4]
        rect = patches.Rectangle((x, y), w, h, linewidth=1,
                                 edgecolor='yellow', facecolor='none')
        ax.add_patch(rect)
        name = idx_to_cls[boxes[i, 1]]
        ax.text(x - 5, y - 5, name, style='italic', color='yellow', fontsize=12)

    ax.imshow(img)


def getFilenames(filepath):
    """Return every file name in *filepath* with the extension stripped."""
    return [os.path.splitext(f)[0] for f in os.listdir(filepath)]


def fliplr_boxes(boxes, W):
    """Mirror *boxes* horizontally in an image of width *W*."""
    boxes_copy = boxes.copy()
    xmin = boxes[:, 2].copy()
    xmax = boxes[:, 3].copy()
    # columns are swapped on purpose so xmin stays <= xmax after mirroring
    boxes_copy[:, 3] = W - 1 - xmin
    boxes_copy[:, 2] = W - 1 - xmax
    return boxes_copy
main.py
import os
import cv2
import numpy as np
import random
import matplotlib.pyplot as plt
import matplotlib.patches as patches

from util import *

img_read_path = "img_less"
xml_read_path = "xml"

img_write_path = "fliped_img"   # flipped images are written here
xml_write_path = "fliped_xml"   # flipped xmls are written here



filenames = getFilenames(xml_read_path)

CLASS_NAMES = ('person', 'aa')  # known bug: even a single class needs a second dummy name

count = 201

wrtin_img_folder_name = "fliped_img"


# Flip every annotated image (and its boxes) horizontally and save both
# under a fresh sequential name so the augmented set can be merged in.
for name in filenames:
    img = cv2.imread(img_read_path + "/" + str(name) + ".jpg")

    boxes, H, W = xml2boxes(xml_read_path + "/" + str(name) + ".xml",
                            CLASS_NAMES)

    H, W, C = img.shape
    ##############################
    fliped_boxes = fliplr_boxes(boxes, W)
    fliped_img = cv2.flip(img, 1)
    ##############################
    FileName = str(count)

    jpgpath = img_write_path + "/" + FileName + ".jpg"
    cv2.imwrite(jpgpath, fliped_img)

    xmlpath = xml_write_path + "/" + FileName + ".xml"
    boxes2xml_labelImg(fliped_boxes, CLASS_NAMES, H, W, xmlpath,
                       wrtin_img_folder_name, FileName, jpgpath)

    count = count + 1
4,训练数据预处理,修改cfg文件
下面这两个文件一定要执行,我从别的地方copy过来的。
trans1.py
import os
import shutil


def listname(path, idtxtpath):
    """Write the stem (name without extension) of every regular file in
    *path* to *idtxtpath*, one id per line; sub-directories are skipped."""
    with open(idtxtpath, 'w') as f:
        for files in os.listdir(path):
            if os.path.isdir(os.path.join(path, files)):
                continue
            f.write(os.path.splitext(files)[0])
            # FIX: the ids were separated by a bare space -- the '\n' lost
            # its backslash when the snippet was pasted. One id per line is
            # the intended format; trans2.py's split() accepts either.
            f.write('\n')


def main():
    """Split the annotated images/xmls into validate and train folders and
    write the two id-list txt files next to them."""
    savepath = os.getcwd()

    img_path = savepath + "/img_less"  # folder with the training images
    xml_path = savepath + "/xml"       # folder with the annotation xmls

    val_num = 10  # size of the validation split, adjustable

    # the four output directories
    validateImage_path = savepath + "/validateImage"
    trainImage_path = savepath + "/trainImage"
    validateImageXML_path = savepath + "/validateImageXML"
    trainImageXML_path = savepath + "/trainImageXML"
    for d in (validateImage_path, trainImage_path,
              validateImageXML_path, trainImageXML_path):
        if not os.path.exists(d):
            os.mkdir(d)

    # the xml folder is the reference: only annotated images are copied
    count = 0
    for files in os.listdir(xml_path):
        filename = os.path.splitext(files)[0]

        origin_jpg_name = os.path.join(img_path, filename + '.jpg')
        xml_olddir = os.path.join(xml_path, filename + ".xml")

        if count < val_num:
            # the first val_num samples form the validation split
            shutil.copy(origin_jpg_name,
                        os.path.join(validateImage_path, filename + '.jpg'))
            shutil.copyfile(xml_olddir,
                            os.path.join(validateImageXML_path,
                                         filename + ".xml"))
        else:
            shutil.copy(origin_jpg_name,
                        os.path.join(trainImage_path, filename + '.jpg'))
            shutil.copyfile(xml_olddir,
                            os.path.join(trainImageXML_path,
                                         filename + ".xml"))

        count = count + 1

    listname(validateImage_path, savepath + "/validateImageId.txt")
    listname(trainImage_path, savepath + "/trainImageId.txt")


# FIX: guard the script entry so importing this module has no side effects.
if __name__ == "__main__":
    main()
trans2.py
import xml.etree.ElementTree as ET
import pickle
import string
import os
import shutil
from os import listdir, getcwd
from os.path import join

sets = [('2012', 'train')]

classes = ["person"]

# UTF-8 byte-order-mark as it appears when the id file is read without
# utf-8-sig decoding; it must be stripped from the first id.
# FIX: the literal had lost its backslashes ('xefxbbxbf') when pasted.
BOM = '\xef\xbb\xbf'


def convert(size, box):
    """Convert a VOC box (xmin, xmax, ymin, ymax) in an image of *size*
    (w, h) into the normalized YOLO (x_center, y_center, w, h) format."""
    dw = 1. / size[0]
    dh = 1. / size[1]
    x = (box[0] + box[1]) / 2.0
    y = (box[2] + box[3]) / 2.0
    w = box[1] - box[0]
    h = box[3] - box[2]
    x = x * dw
    w = w * dw
    y = y * dh
    h = h * dh
    return (x, y, w, h)


def convert_annotation(image_id, flag, savepath):
    """Turn one VOC xml into a YOLO label txt.

    flag == 0 -> training split, flag == 1 -> validation split.
    """
    if flag == 0:
        xml_file = savepath + '/trainImageXML/%s.xml' % (image_id)
        labeltxt = savepath + '/trainImageLabelTxt'
    elif flag == 1:
        xml_file = savepath + '/validateImageXML/%s.xml' % (image_id)
        labeltxt = savepath + '/validateImageLabelTxt'
    if os.path.exists(labeltxt) == False:
        os.mkdir(labeltxt)

    tree = ET.parse(xml_file)
    root = tree.getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)

    with open(labeltxt + '/%s.txt' % (image_id), 'w') as out_file:
        for obj in root.iter('object'):
            # difficult = obj.find('difficult').text
            cls = obj.find('name').text
            # if cls not in classes or int(difficult) == 1:
            #     continue
            cls_id = classes.index(cls)
            xmlbox = obj.find('bndbox')
            b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text),
                 float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
            bb = convert((w, h), b)
            # FIX: darknet expects one object per line; the trailing '\n'
            # had been garbled into a space in the pasted snippet.
            out_file.write(str(cls_id) + " " +
                           " ".join([str(a) for a in bb]) + '\n')


# FIX: guard the script entry so the module can be imported (e.g. for
# tests) without immediately reading the id files.
if __name__ == "__main__":
    wd = getcwd()

    for year, image_set in sets:
        savepath = os.getcwd()

        # validation split
        idtxt = savepath + "/validateImageId.txt"
        pathtxt = savepath + "/validateImagePath.txt"
        image_ids = open(idtxt).read().strip().split()
        with open(pathtxt, 'w') as list_file:
            for image_id in image_ids:
                if image_id.find(BOM) >= 0:
                    image_id = image_id[3:]  # drop the 3 BOM characters
                # FIX: one image path per line (newline was garbled to a space)
                list_file.write('%s/validateImage/%s.jpg\n' % (wd, image_id))
                print(image_id)
                convert_annotation(image_id, 1, savepath)

        # training split
        idtxt = savepath + "/trainImageId.txt"
        pathtxt = savepath + "/trainImagePath.txt"
        image_ids = open(idtxt).read().strip().split()
        with open(pathtxt, 'w') as list_file:
            for image_id in image_ids:
                if image_id.find(BOM) >= 0:
                    image_id = image_id[3:]
                list_file.write('%s/trainImage/%s.jpg\n' % (wd, image_id))
                print(image_id)
                convert_annotation(image_id, 0, savepath)
训练的时候,编译好了,在windows下生成了darknet.exe,ubuntu下生成了可执行文件darknet。
然后在darknet.exe同级目录新建一个文件夹叫(比如训练做行人识别)train_person,在该文件夹下新建一个
backup文件夹,新建一个data文件夹,把数据增强过的图片文件夹img_less和xml拷贝到data下面,将trans1.py
和trans2.py也放到data文件夹下,然后先执行trans1.py,再执行trans2.py。然后把生成的txt拷贝到训练图片目录下。
至于修改cfg文件,看其他的博客吧。我就提一下修改那个anchors,faster-rcnn中有9个比例确定的anchor。yolo中则是
统计了样本中的标注框然后做聚类。用了效果不错我才修改的。
# coding=utf-8
# k-means++ for YOLOv2 anchors:
# derive the anchor sizes YOLOv2 needs by clustering the training boxes.
import numpy as np


class Box():
    """A bounding box given by center (x, y) and size (w, h)."""

    def __init__(self, x, y, w, h):
        self.x = x
        self.y = y
        self.w = w
        self.h = h


def overlap(x1, len1, x2, len2):
    """Length of the 1-D overlap of two segments given by center and
    length (negative when they do not overlap)."""
    len1_half = len1 / 2
    len2_half = len2 / 2
    left = max(x1 - len1_half, x2 - len2_half)
    right = min(x1 + len1_half, x2 + len2_half)
    return right - left


def box_intersection(a, b):
    """Intersection area of boxes *a* and *b* (0 if disjoint)."""
    w = overlap(a.x, a.w, b.x, b.w)
    h = overlap(a.y, a.h, b.y, b.h)
    if w < 0 or h < 0:
        return 0
    return w * h


def box_union(a, b):
    """Union area of boxes *a* and *b*."""
    i = box_intersection(a, b)
    return a.w * a.h + b.w * b.h - i


def box_iou(a, b):
    """Intersection-over-union of boxes *a* and *b*."""
    return box_intersection(a, b) / box_union(a, b)


def init_centroids(boxes, n_anchors):
    """k-means++ seeding: pick *n_anchors* initial centroids from *boxes*
    so the result depends less on a purely random initialization."""
    centroids = []
    boxes_num = len(boxes)

    # FIX: np.random.choice returns an ndarray; index the list with a
    # plain int (indexing a list with an array fails on modern numpy).
    first = int(np.random.choice(boxes_num, 1)[0])
    centroids.append(boxes[first])

    print(centroids[0].w, centroids[0].h)

    for _ in range(0, n_anchors - 1):
        sum_distance = 0
        distance_list = []
        cur_sum = 0

        for box in boxes:
            min_distance = 1
            for centroid in centroids:
                distance = (1 - box_iou(box, centroid))
                if distance < min_distance:
                    min_distance = distance
            sum_distance += min_distance
            distance_list.append(min_distance)

        # sample the next centroid with probability proportional to its
        # distance from the existing centroids
        distance_thresh = sum_distance * np.random.random()
        for i in range(0, boxes_num):
            cur_sum += distance_list[i]
            if cur_sum > distance_thresh:
                centroids.append(boxes[i])
                print(boxes[i].w, boxes[i].h)
                break

    return centroids


def do_kmeans(n_anchors, boxes, centroids):
    """One k-means iteration with 1-IoU as the distance.

    Returns (new_centroids, groups, loss): groups[i] holds the boxes
    assigned to centroid i; loss is the summed distance of every box to
    its nearest centroid.
    """
    loss = 0
    groups = []
    new_centroids = []
    for i in range(n_anchors):
        groups.append([])
        new_centroids.append(Box(0, 0, 0, 0))

    for box in boxes:
        min_distance = 1
        group_index = 0
        for centroid_index, centroid in enumerate(centroids):
            distance = (1 - box_iou(box, centroid))
            if distance < min_distance:
                min_distance = distance
                group_index = centroid_index
        groups[group_index].append(box)
        loss += min_distance
        new_centroids[group_index].w += box.w
        new_centroids[group_index].h += box.h

    for i in range(n_anchors):
        if groups[i]:
            new_centroids[i].w /= len(groups[i])
            new_centroids[i].h /= len(groups[i])
        else:
            # FIX: an empty cluster used to raise ZeroDivisionError;
            # keep the previous centroid instead.
            new_centroids[i].w = centroids[i].w
            new_centroids[i].h = centroids[i].h

    return new_centroids, groups, loss


def compute_centroids(label_path, n_anchors, loss_convergence, grid_size,
                      iterations_num, plus):
    """Cluster the label boxes reachable from *label_path* into anchors.

    label_path       -- txt listing the training image paths
    n_anchors        -- number of anchors (the k of k-means)
    loss_convergence -- stop when the loss changes less than this
    grid_size        -- output grid of the network (13 for a 416 input)
    iterations_num   -- maximum number of iterations
    plus             -- 1 to seed with k-means++, 0 for random seeding
    """
    boxes = []
    label_files = []
    with open(label_path) as f:
        for line in f:
            # map each image path to its darknet label txt
            path = line.rstrip().replace('images', 'labels')
            path = path.replace('JPEGImages', 'labels')
            path = path.replace('.jpg', '.txt')
            path = path.replace('.JPEG', '.txt')
            label_files.append(path)

    for label_file in label_files:
        with open(label_file) as f:
            for line in f:
                temp = line.strip().split(" ")
                if len(temp) > 1:
                    # columns: class x y w h -- only w/h matter here
                    boxes.append(Box(0, 0, float(temp[3]), float(temp[4])))
                    print(temp[3], temp[4])
                    if float(temp[3]) < 0:
                        print(label_file)
    print('done')

    if plus:
        centroids = init_centroids(boxes, n_anchors)
    else:
        centroid_indices = np.random.choice(len(boxes), n_anchors)
        centroids = [boxes[int(ci)] for ci in centroid_indices]

    # iterate k-means until the loss stops improving
    centroids, groups, old_loss = do_kmeans(n_anchors, boxes, centroids)
    iterations = 1
    while True:
        centroids, groups, loss = do_kmeans(n_anchors, boxes, centroids)
        iterations = iterations + 1
        print("loss = %f" % loss)
        if abs(old_loss - loss) < loss_convergence or iterations > iterations_num:
            break
        old_loss = loss

    for centroid in centroids:
        print(centroid.w * grid_size, centroid.h * grid_size)

    # print result
    for centroid in centroids:
        print("k-means result: ", str('%.06f' % (centroid.w * grid_size)), ",",
              str('%.06f' % (centroid.h * grid_size)))


# FIX: guard the script entry so the module can be imported (e.g. for
# tests) without immediately reading the hard-coded label list.
if __name__ == "__main__":
    label_path = "xx/train_person/data/trainImagePath.txt"
    n_anchors = 5
    loss_convergence = 1e-3
    grid_size = 13
    iterations_num = 10000000
    plus = 0
    compute_centroids(label_path, n_anchors, loss_convergence, grid_size,
                      iterations_num, plus)
5,训练和测试
直接命令行了。训练的时候内存够的话,再开一个命令行窗口做测试,可以一边训练,一边偶尔看看预测效果。
6,在opencv中调用
C++版本
1 #include <opencv2/opencv.hpp> 2 #include <opencv2/dnn.hpp> 3 4 #include <fstream> 5 #include <iostream> 6 #include <algorithm> 7 #include <cstdlib> 8 9 using namespace std; 10 using namespace cv; 11 using namespace cv::dnn; 12 13 //char* jpgpath = ""; 14 //char* cfgpath = ""; 15 //char* weightspath = ""; 16 //char* namespath = ""; 17 //Mat boxes = yoloMultiPredict(jpgpath, cfgpath, weightspath, namespath); 18 //cout << "boxes: " << boxes << endl; // class prob xmin xmax ymin ymax 19 Mat yoloMultiPredict(char* jpgpath, char* cfgpath, char* weightspath, char* namespath) 20 { 21 Mat boxes = Mat::zeros(0, 6, CV_16UC1); 22 Mat frame = imread(jpgpath); 23 24 dnn::Net net = readNetFromDarknet(cfgpath, weightspath); 25 if (net.empty()) 26 { 27 printf("Could not load net... "); 28 } 29 30 31 // 得到类别名称 32 if (0) 33 { 34 ifstream classNamesFile(namespath); 35 vector<string> classNamesVec; 36 37 if (classNamesFile.is_open()) 38 { 39 string className = ""; 40 while (std::getline(classNamesFile, className)) 41 classNamesVec.push_back(className); 42 } 43 for (int i = 0; i < classNamesVec.size(); i++) 44 cout << i << " " << classNamesVec[i] << endl; 45 cout << endl; 46 } 47 48 Mat inputBlob = blobFromImage(frame, 1 / 255.F, Size(416, 416), Scalar(), true, false); 49 net.setInput(inputBlob, "data"); 50 51 // 检测 52 Mat detectionMat = net.forward("detection_out"); 53 //cout << "forward" << endl; 54 55 // 输出结果 56 for (int i = 0; i < detectionMat.rows; i++) 57 { 58 const int probability_index = 5; 59 const int probability_size = detectionMat.cols - probability_index; 60 float *prob_array_ptr = &detectionMat.at<float>(i, probability_index); 61 size_t objectClass = max_element(prob_array_ptr, prob_array_ptr + probability_size) - prob_array_ptr; 62 63 float confidence = detectionMat.at<float>(i, (int)objectClass + probability_index); 64 65 if (confidence > 0.24) 66 { 67 float x = detectionMat.at<float>(i, 0); 68 float y = detectionMat.at<float>(i, 1); 69 float width = 
detectionMat.at<float>(i, 2); 70 float height = detectionMat.at<float>(i, 3); 71 72 int xmin = static_cast<int>((x - width / 2) * frame.cols); 73 int xmax = static_cast<int>((x + width / 2) * frame.cols); 74 75 int ymin = static_cast<int>((y - height / 2) * frame.rows); 76 int ymax = static_cast<int>((y + height / 2) * frame.rows); 77 78 // clip 79 if (xmin<0) 80 xmin = 0; 81 if (xmax > frame.cols) 82 xmax = frame.cols - 1; 83 if (ymin<0) 84 ymin = 0; 85 if (ymax > frame.rows) 86 ymax = frame.rows - 1; 87 88 //rectangle(frame, cvPoint(xmin, ymin), cvPoint(xmax, ymax), Scalar(0, 0, 255), 4, 1, 0); 89 //cout << "x y w h " << x << " " << y << " " << width << " " << height << endl; 90 91 // class prob xmin xmax ymin ymax 92 Mat L = (Mat_<short>(1, 6) << (short)objectClass, (short)(confidence * 100), xmin, xmax, ymin, ymax); 93 //cout << "L:" << L << endl; 94 boxes.push_back(L); 95 } 96 } 97 98 return boxes; 99 100 }
2018年4月18日以后的opencv已经可以导入yolo v3的训练文件了。
直接用命令行做的预测结果和opencv中导入配置文件的预测结果不一样,有的时候相差还很大,不堪重用啊。
还是要在什么框架下训练就在什么框架下用,不然数值稳定性不能保证。。。
python版本
opencv的samples/dnn下的例子改造的。
import cv2
import numpy as np
import os

# Directory containing this script.
cwd = os.path.split(os.path.realpath(__file__))[0]


def darknetPredict(jpgpathOrMat, cfgpath, wtspath):
    """Run a darknet cfg/weights pair on an image through OpenCV's DNN
    module and return an array whose rows are
    [classid, prob, xmin, ymin, xmax, ymax].

    *jpgpathOrMat* may be an image path or an already-loaded BGR ndarray.
    """
    net = cv2.dnn.readNetFromDarknet(cfgpath, wtspath)

    confThreshold = 0.24  # minimum class confidence
    nmsThreshold = 0.4    # NMS IoU threshold

    def getOutputsNames(model):
        layerNames = model.getLayerNames()
        return [layerNames[j[0] - 1] for j in model.getUnconnectedOutLayers()]

    if isinstance(jpgpathOrMat, str):
        frame = cv2.imread(jpgpathOrMat)
    if isinstance(jpgpathOrMat, np.ndarray):
        frame = jpgpathOrMat

    frameHeight, frameWidth = frame.shape[0], frame.shape[1]

    # Create a 4D blob from the frame and run the model.
    inpW, inpH = 416, 416
    blob = cv2.dnn.blobFromImage(frame, 1.0 / 255, (inpW, inpH), (0, 0, 0),
                                 swapRB=True, crop=False)
    net.setInput(blob)
    outs = net.forward(getOutputsNames(net))

    # Collect every detection above the confidence threshold.
    classIds, confidences, boxes = [], [], []
    for out in outs:
        for detection in out:
            scores = detection[5:]
            classId = np.argmax(scores)
            confidence = scores[classId]
            if confidence > confThreshold:
                center_x = int(detection[0] * frameWidth)
                center_y = int(detection[1] * frameHeight)
                width = int(detection[2] * frameWidth)
                height = int(detection[3] * frameHeight)
                classIds.append(classId)
                confidences.append(float(confidence))
                boxes.append([int(center_x - width / 2),
                              int(center_y - height / 2), width, height])

    # Non-maximum suppression, then convert the surviving
    # [left, top, w, h] boxes to clipped corner coordinates.
    rst_boxes = []
    keep = cv2.dnn.NMSBoxes(boxes, confidences, confThreshold, nmsThreshold)
    for idx in keep:
        idx = idx[0]
        left, top, width, height = boxes[idx]

        xmin = np.clip(left, 0, frameWidth - 1)
        xmax = np.clip(left + width, 0, frameWidth - 1)
        ymin = np.clip(top, 0, frameHeight - 1)
        ymax = np.clip(top + height, 0, frameHeight - 1)

        # row layout: classid prob xmin ymin xmax ymax
        rst_boxes.append([classIds[idx], confidences[idx],
                          xmin, ymin, xmax, ymax])

    return np.asarray(rst_boxes)
差不多就是这样了。