• 【目标检测】NMS非极大值抑制代码示例


     1 import numpy as np
     2 
     3 def non_max_suppress(predicts_dict, threshold=0.2):
     4     """
     5     implement non-maximum supression on predict bounding boxes.
     6     Args:
     7         predicts_dict: {"stick": [[x1, y1, x2, y2, scores1], [...]]}.
     8         threshhold: iou threshold
     9     Return:
    10         predicts_dict processed by non-maximum suppression
    11     """
    12     for object_name, bbox in predicts_dict.items(): #对每一个类别的目标分别进行NMS
    13         bbox_array = np.array(bbox, dtype=np.float)
    14  
    15         ## 获取当前目标类别下所有矩形框(bounding box,下面简称bbx)的坐标和confidence,并计算所有bbx的面积
    16         x1, y1, x2, y2, scores = bbox_array[:,0], bbox_array[:,1], bbox_array[:,2], bbox_array[:,3], bbox_array[:,4]
    17         areas = (x2-x1+1) * (y2-y1+1)
    18         #print("areas shape = ", areas.shape)
    19  
    20         ## 对当前类别下所有的bbx的confidence进行从高到低排序(order保存索引信息)
    21         order = scores.argsort()[::-1]
    22         print("类别%s的order = "%object_name, order)
    23         keep = [] #用来存放最终保留的bbx的索引信息
    24         k = 1 
    25         ## 依次从按confidence从高到低遍历bbx,移除所有与该矩形框的IOU值大于threshold的矩形框
    26         while order.size > 0:
    27             print('第%d次遍历'%(k))
    28             i = order[0]
    29             keep.append(i) #保留当前最大confidence对应的bbx索引
    30  
    31             ## 获取所有与当前bbx的交集对应的左上角和右下角坐标,并计算IOU(注意这里是同时计算一个bbx与其他所有bbx的IOU)
    32             xx1 = np.maximum(x1[i], x1[order[1:]]) #当order.size=1时,下面的计算结果都为np.array([]),不影响最终结果
    33             yy1 = np.maximum(y1[i], y1[order[1:]])
    34             xx2 = np.minimum(x2[i], x2[order[1:]])
    35             yy2 = np.minimum(y2[i], y2[order[1:]])
    36             inter = np.maximum(0.0, xx2-xx1+1) * np.maximum(0.0, yy2-yy1+1)
    37             iou = inter/(areas[i]+areas[order[1:]]-inter)
    38             print("iou =", iou)
    39  
    40             print(np.where(iou<=threshold)) #输出没有被移除的bbx索引(相对于iou向量的索引)
    41             indexs = np.where(iou<=threshold)[0] + 1 #获取保留下来的索引(因为没有计算与自身的IOU,所以索引相差1,需要加上)
    42             print("indexs = ", indexs)
    43             order = order[indexs] #更新保留下来的索引, ( array([0, 1, 2]),)
    44             print("order = ", order)
    45             k+=1
    46         bbox = bbox_array[keep]
    47         predicts_dict[object_name] = bbox.tolist()
    48         predicts_dict = predicts_dict
    49     return predicts_dict
    50     
    51 if __name__ == "__main__":
    52     #predicts_dict={"cup":[[894, 354, 63, 60, 0.64], [648, 386, 72, 59, 0.91],[772, 233, 30, 43, 0.66], [723, 246, 50, 41, 0.89]]}
    53     predicts_dict={"cup":[[647, 385, 789, 501, 0.64], [648, 386, 792, 504, 0.91],
    54                             [772, 233, 832, 319, 0.66], [767, 224, 828, 309, 0.78], [723, 246, 823, 328, 0.89]],
    55                    "person":[[647, 385, 789, 501, 0.64], [648, 386, 792, 504, 0.91],
    56                             [772, 233, 832, 319, 0.66], [767, 224, 828, 309, 0.78], [723, 246, 823, 328, 0.89]]}
    57     predicts_dict=non_max_suppress(predicts_dict, threshold=0.2)
    58     print(predicts_dict)

    运行结果:

     1 类别cup的order =  [1 4 3 2 0]
     2 第1次遍历
     3 iou = [0.         0.         0.         0.94050474]
     4 (array([0, 1, 2]),)
     5 indexs =  [1 2 3]
     6 order =  [4 3 2]
     7 第2次遍历
     8 iou = [0.36237211 0.39097744]
     9 (array([], dtype=int64),)
    10 indexs =  []
    11 order =  []
    12 类别person的order =  [1 4 3 2 0]
    13 第1次遍历
    14 iou = [0.         0.         0.         0.94050474]
    15 (array([0, 1, 2]),)
    16 indexs =  [1 2 3]
    17 order =  [4 3 2]
    18 第2次遍历
    19 iou = [0.36237211 0.39097744]
    20 (array([], dtype=int64),)
    21 indexs =  []
    22 order =  []
    23 {'cup': [[648.0, 386.0, 792.0, 504.0, 0.91], [723.0, 246.0, 823.0, 328.0, 0.89]], 'person': [[648.0, 386.0, 792.0, 504.0, 0.91], [723.0, 246.0, 823.0, 328.0, 0.89]]}

    参考博客:

     https://blog.csdn.net/m0_37605642/article/details/98358864

  • 相关阅读:
    老友记实战,17上
    老友记实战,9下
    老友记实战,5下
    公共样式base.css
    单选框radio总结(获取值、设置默认选中值、样式)
    js tab切换
    HTTP状态码100、200、300、400、500、600的含义
    微信小程序事件绑定
    微信小程序获取手机验证码
    js滚动到指定位置导航栏固定顶部
  • 原文地址:https://www.cnblogs.com/DJames23/p/12508453.html
Copyright © 2020-2023  润新知