import numpy as np


def _nms_keep_indices(boxes, threshold, label, verbose):
    """Run greedy NMS on one class's boxes and return the indices of the kept rows.

    Args:
        boxes: float array whose columns 0-4 are ``x1, y1, x2, y2, score``.
        threshold: a box is dropped when its IoU with the currently selected
            (higher-scoring) box is greater than this value.
        label: class name, used only in the verbose trace.
        verbose: if True, print each step of the suppression loop.

    Returns:
        List of row indices into ``boxes``, ordered by descending score.
    """
    x1, y1, x2, y2, scores = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3], boxes[:, 4]
    # Coordinates are treated as inclusive pixel indices, hence the +1.
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)

    # Box indices sorted by confidence, highest first.
    order = scores.argsort()[::-1]
    if verbose:
        print("类别%s的order = " % label, order)

    keep = []
    step = 1
    # Repeatedly keep the best remaining box and drop every remaining box
    # whose IoU with it exceeds the threshold.
    while order.size > 0:
        if verbose:
            print('第%d次遍历' % step)
        best = order[0]
        keep.append(best)
        rest = order[1:]

        # Intersection of `best` with all remaining boxes at once; when `rest`
        # is empty these are empty arrays and the loop simply terminates.
        xx1 = np.maximum(x1[best], x1[rest])
        yy1 = np.maximum(y1[best], y1[rest])
        xx2 = np.minimum(x2[best], x2[rest])
        yy2 = np.minimum(y2[best], y2[rest])
        inter = np.maximum(0.0, xx2 - xx1 + 1) * np.maximum(0.0, yy2 - yy1 + 1)
        iou = inter / (areas[best] + areas[rest] - inter)

        survivors = np.where(iou <= threshold)
        # `iou` is indexed relative to order[1:], so shift by one to index `order`.
        indexs = survivors[0] + 1
        if verbose:
            print("iou =", iou)
            print(survivors)
            print("indexs = ", indexs)
        order = order[indexs]
        if verbose:
            print("order = ", order)
        step += 1
    return keep


def non_max_suppress(predicts_dict, threshold=0.2, verbose=False):
    """Apply non-maximum suppression (NMS) to predicted bounding boxes, per class.

    Args:
        predicts_dict: mapping from class name to a list of boxes, each box
            ``[x1, y1, x2, y2, score]``, e.g.
            ``{"stick": [[x1, y1, x2, y2, score1], [...]]}``. Any extra
            trailing columns are carried through unchanged.
        threshold: IoU threshold; a box whose IoU with an already-kept,
            higher-scoring box of the same class is greater than this value
            is removed.
        verbose: if True, print the intermediate steps of the suppression
            loop for every class.

    Returns:
        The same ``predicts_dict`` object, updated in place: each class maps
        to its surviving boxes (as lists of floats) ordered by descending
        score. A class with no boxes maps to an empty list.
    """
    for object_name, bbox in predicts_dict.items():  # NMS is done separately for each class
        # Builtin float instead of np.float, which was removed in NumPy 1.24.
        bbox_array = np.array(bbox, dtype=float)
        if bbox_array.size == 0:
            # No detections for this class: nothing to suppress.
            predicts_dict[object_name] = []
            continue
        keep = _nms_keep_indices(bbox_array, threshold, object_name, verbose)
        # Reassigning an existing key does not change the dict's size, so this
        # is safe while iterating over items().
        predicts_dict[object_name] = bbox_array[keep].tolist()
    return predicts_dict


if __name__ == "__main__":
    predicts_dict = {
        "cup": [[647, 385, 789, 501, 0.64], [648, 386, 792, 504, 0.91],
                [772, 233, 832, 319, 0.66], [767, 224, 828, 309, 0.78],
                [723, 246, 823, 328, 0.89]],
        "person": [[647, 385, 789, 501, 0.64], [648, 386, 792, 504, 0.91],
                   [772, 233, 832, 319, 0.66], [767, 224, 828, 309, 0.78],
                   [723, 246, 823, 328, 0.89]],
    }
    predicts_dict = non_max_suppress(predicts_dict, threshold=0.2, verbose=True)
    print(predicts_dict)
运行结果:
类别cup的order = [1 4 3 2 0]
第1次遍历
iou = [0. 0. 0. 0.94050474]
(array([0, 1, 2]),)
indexs = [1 2 3]
order = [4 3 2]
第2次遍历
iou = [0.36237211 0.39097744]
(array([], dtype=int64),)
indexs = []
order = []
类别person的order = [1 4 3 2 0]
第1次遍历
iou = [0. 0. 0. 0.94050474]
(array([0, 1, 2]),)
indexs = [1 2 3]
order = [4 3 2]
第2次遍历
iou = [0.36237211 0.39097744]
(array([], dtype=int64),)
indexs = []
order = []
{'cup': [[648.0, 386.0, 792.0, 504.0, 0.91], [723.0, 246.0, 823.0, 328.0, 0.89]], 'person': [[648.0, 386.0, 792.0, 504.0, 0.91], [723.0, 246.0, 823.0, 328.0, 0.89]]}
参考博客: