• PyTorch implementation of YOLOv3 (4): non-maximum suppression (NMS)


    In the previous article we implemented the forward function and obtained prediction. At this point a very large number of boxes and class probabilities have been predicted, and we now need to filter out the final predicted boxes.
    Once you understand the format of the YOLOv3 output and what each position means, the source code is not hard to follow. My main difficulty while reading it was that I was not familiar with PyTorch, so in this article the PyTorch functions involved are marked in bold, together with links and small test snippets.

    obj score threshold

    We set an objectness score threshold; only predictions whose objectness score exceeds it are treated as valid.

        conf_mask = (prediction[:,:,4] > confidence).float().unsqueeze(2)
        prediction = prediction*conf_mask
    

    prediction has shape 1 * boxnum * boxattr.
    prediction[:,:,4] has shape 1 * boxnum; each element is the value at index 4 of boxattr.

    Tensor indexing in torch works much like numpy; see the output of the following code.

    import torch
    x = torch.Tensor(1,3,10)    # create an un-initialized Tensor of size 1x3x10
    print(x)                    # print out the Tensor
    print(x.shape)
    
    y = x[:,:,4]
    print(y)
    print(y.shape)
    
    z = x[:,:,4:6]
    print(z)
    print(z.shape)
    
    print((y>0.5).float().unsqueeze(2))
    
    #### output
    tensor([[[2.5226e-18, 1.6898e-04, 1.0413e-11, 7.7198e-10, 1.0549e-08,
              4.0516e-11, 1.0681e-05, 2.9575e-18, 6.7333e+22, 1.7591e+22],
             [1.7184e+25, 4.3222e+27, 6.1972e-04, 7.2443e+22, 1.7728e+28,
              7.0367e+22, 5.9018e-10, 2.6540e-09, 1.2972e-11, 5.3370e-08],
             [2.7001e-06, 2.6801e-09, 4.1292e-05, 2.1511e+23, 3.2770e-09,
              2.5125e-18, 7.7052e+31, 1.9447e+31, 5.0207e+28, 1.1492e-38]]])
    torch.Size([1, 3, 10])
    tensor([[1.0549e-08, 1.7728e+28, 3.2770e-09]])
    torch.Size([1, 3])
    tensor([[[1.0549e-08, 4.0516e-11],
             [1.7728e+28, 7.0367e+22],
             [3.2770e-09, 2.5125e-18]]])
    torch.Size([1, 3, 2])
    
    tensor([[[0.],
             [1.],
             [0.]]])
    

    squeeze and unsqueeze remove and add dimensions of size 1, respectively.

    t = torch.ones(2,1,2,1) # Size 2x1x2x1
    r = torch.squeeze(t)     # Size 2x2
    r = torch.squeeze(t, 1)  # Squeeze dimension 1: Size 2x2x1
    
    # Un-squeeze a dimension
    x = torch.Tensor([1, 2, 3])
    r = torch.unsqueeze(x, 0)       # Size: 1x3  -- insert a new dimension at position 0
    r = torch.unsqueeze(x, 1)       # Size: 3x1  -- insert a new dimension at position 1
    
    

    Now every entry of prediction whose objectness score is below the threshold has been zeroed out.
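    A minimal sketch, with made-up numbers, of how multiplying by the mask zeroes out whole rows:

    # fake prediction: 1 image, 3 boxes, 6 attrs (x, y, w, h, obj score, 1 class prob)
    prediction = torch.tensor([[[10., 10., 5., 5., 0.9, 0.8],
                                [20., 20., 5., 5., 0.1, 0.7],
                                [30., 30., 5., 5., 0.6, 0.2]]])
    confidence = 0.5
    conf_mask = (prediction[:,:,4] > confidence).float().unsqueeze(2)   # shape 1x3x1
    print(prediction*conf_mask)   # the second box (obj score 0.1) becomes an all-zero row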

    nms

    tensor.new() creates a new tensor with the same dtype as the original tensor: https://stackoverflow.com/questions/49263588/pytorch-beginner-tensor-new-method
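
    A tiny illustration of what new() gives you (the values of the new tensor are uninitialized):

    a = torch.randn(2, 3)
    b = a.new(2, 5)               # uninitialized tensor, same dtype and device as a, shape 2x5
    print(b.dtype == a.dtype)     # True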

        #get the box coordinates (top-left corner x, top-left corner y, bottom-right corner x, bottom-right corner y)
        box_corner = prediction.new(prediction.shape)
        box_corner[:,:,0] = (prediction[:,:,0] - prediction[:,:,2]/2)
        box_corner[:,:,1] = (prediction[:,:,1] - prediction[:,:,3]/2)
        box_corner[:,:,2] = (prediction[:,:,0] + prediction[:,:,2]/2) 
        box_corner[:,:,3] = (prediction[:,:,1] + prediction[:,:,3]/2)
        prediction[:,:,:4] = box_corner[:,:,:4]
    

    In the original prediction, boxattr stores (center x, center y, w, h, ...), which is inconvenient to work with, so we convert it to (top-left corner x, top-left corner y, bottom-right corner x, bottom-right corner y).
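    For example (made-up numbers): a box with center (100, 100), width 40 and height 60 becomes top-left (80, 70) and bottom-right (120, 130).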

    Next we process the feature map of each image in the batch, one image at a time.

        batch_size = prediction.size(0)
        write = False
    
        for ind in range(batch_size):
            #image_pred.shape=boxnum*boxattr
            image_pred = prediction[ind]          #image Tensor  box_num*box_attr
            #confidence threshholding 
            #NMS
            #returns the maximum value of each row, and the column index where that maximum sits
            max_conf, max_conf_score = torch.max(image_pred[:,5:5+ num_classes], 1)
            #unsqueeze so the number of dimensions matches image_pred
            max_conf = max_conf.float().unsqueeze(1)
            max_conf_score = max_conf_score.float().unsqueeze(1)
            seq = (image_pred[:,:5], max_conf, max_conf_score)
            
            #concatenate along the column direction; image_pred is now boxnum*7
            image_pred = torch.cat(seq, 1)
            
            
    
    

    This uses torch.max; see https://blog.csdn.net/Z_lbj/article/details/79766690
    torch.max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)
    Returns the maximum along dimension dim. A way to remember it: the comparison runs along dimension dim, so torch.max(a, 0) compares down the rows and returns the maximum of each column.
    Assume input is a 2-D matrix, i.e. rows * columns; rows are dimension 0 and columns are dimension 1.

    • torch.max(a,0) returns the maximum element of each column, together with its index (the row index of that maximum within its column)
    • torch.max(a,1) returns the maximum element of each row, together with its index (the column index of that maximum within its row)
    c=torch.Tensor([[1,2,3],[6,5,4]])
    print(c)
    a,b=torch.max(c,1)
    print(a)
    print(b)
    
    ##output:
    tensor([[1., 2., 3.],
            [6., 5., 4.]])
    tensor([3., 6.])
    tensor([2, 0])
    

    For torch.cat usage, see https://pytorch.org/docs/stable/torch.html

    torch.cat(tensors, dim=0, out=None) → Tensor
    >>> x = torch.randn(2, 3)
    >>> x
    tensor([[ 0.6580, -1.0969, -0.4614],
            [-0.1034, -0.5790,  0.1497]])
    >>> torch.cat((x, x, x), 0)
    tensor([[ 0.6580, -1.0969, -0.4614],
            [-0.1034, -0.5790,  0.1497],
            [ 0.6580, -1.0969, -0.4614],
            [-0.1034, -0.5790,  0.1497],
            [ 0.6580, -1.0969, -0.4614],
            [-0.1034, -0.5790,  0.1497]])
    >>> torch.cat((x, x, x), 1)
    tensor([[ 0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614,  0.6580,
             -1.0969, -0.4614],
            [-0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497, -0.1034,
             -0.5790,  0.1497]])
    

    Next we keep only the rows whose obj_score is non-zero (rows with obj_score < obj_threshold were set to 0 above).

            non_zero_ind =  (torch.nonzero(image_pred[:,4]))
            try:
                image_pred_ = image_pred[non_zero_ind.squeeze(),:].view(-1,7)
            except:
                continue
    
            #For PyTorch 0.4 compatibility
            #Since the above code will not raise an exception when there is no detection
            #as scalars are supported in PyTorch 0.4
            if image_pred_.shape[0] == 0:
                continue 
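
    The filtering above relies on torch.nonzero, which returns the indices of the non-zero elements; a small standalone illustration with made-up values:

    t = torch.tensor([0.9, 0.0, 0.7, 0.0])
    print(torch.nonzero(t))            # tensor([[0], [2]]) -- one row per non-zero element
    print(torch.nonzero(t).squeeze())  # tensor([0, 2]), usable as a row index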
    
    

    OK, next we run NMS for each class.
    First, get which classes are present in the image.

            #Get the various classes detected in the image
            img_classes = unique(image_pred_[:,-1])  # -1 index holds the class index
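
    unique here is a helper defined elsewhere in the tutorial code, not a torch built-in; a minimal sketch of what it does (deduplicate the class indices), roughly following the original tutorial, could look like this:

    import numpy as np
    import torch

    def unique(tensor):
        # deduplicate values by going through numpy
        tensor_np = tensor.cpu().numpy()
        unique_np = np.unique(tensor_np)
        unique_tensor = torch.from_numpy(unique_np)

        # copy back into a tensor with the same dtype/device as the input
        tensor_res = tensor.new(unique_tensor.shape)
        tensor_res.copy_(unique_tensor)
        return tensor_res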
    

    Then we handle each class in turn.

            for cls in img_classes:
                #perform NMS
    
            
                #get the detections with one particular class
                #keep the rows whose predicted class is the current class (class prob != 0)
                cls_mask = image_pred_*(image_pred_[:,-1] == cls).float().unsqueeze(1)
                class_mask_ind = torch.nonzero(cls_mask[:,-2]).squeeze()
                image_pred_class = image_pred_[class_mask_ind].view(-1,7)
                
                #sort the detections such that the entry with the maximum objectness
                #confidence is at the top
                #sort by obj score from high to low
                conf_sort_index = torch.sort(image_pred_class[:,4], descending = True )[1]
                image_pred_class = image_pred_class[conf_sort_index]
                idx = image_pred_class.size(0)   #Number of detections
                
                for i in range(idx):
                    #Get the IOUs of all boxes that come after the one we are looking at 
                    #in the loop
                    try:
                        #compute the IoU between box i and every box after it
                        ious = bbox_iou(image_pred_class[i].unsqueeze(0), image_pred_class[i+1:])
                    except ValueError:
                        break
                
                    except IndexError:
                        break
                
                    #Zero out all the detections that have IoU > threshold
                    #boxes whose IoU with box i exceeds nms_conf are treated as the same object and zeroed out
                    iou_mask = (ious < nms_conf).float().unsqueeze(1)
                    image_pred_class[i+1:] *= iou_mask       
                
                    #remove the rows that have been zeroed out (IoU > nms_conf)
                    non_zero_ind = torch.nonzero(image_pred_class[:,4]).squeeze()
                    image_pred_class = image_pred_class[non_zero_ind].view(-1,7)
                    
                batch_ind = image_pred_class.new(image_pred_class.size(0), 1).fill_(ind)      #Repeat the batch_id for as many detections of the class cls in the image
                seq = batch_ind, image_pred_class
    

    The IoU computation is shown below and needs little explanation: IoU = intersection area / union area.

    def bbox_iou(box1, box2):
        """
        Returns the IoU of two bounding boxes 
        
        
        """
        #Get the coordinates of bounding boxes
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[:,0], box1[:,1], box1[:,2], box1[:,3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[:,0], box2[:,1], box2[:,2], box2[:,3]
        
        #get the coordinates of the intersection rectangle
        inter_rect_x1 =  torch.max(b1_x1, b2_x1)
        inter_rect_y1 =  torch.max(b1_y1, b2_y1)
        inter_rect_x2 =  torch.min(b1_x2, b2_x2)
        inter_rect_y2 =  torch.min(b1_y2, b2_y2)
        
        #Intersection area
        inter_area = torch.clamp(inter_rect_x2 - inter_rect_x1 + 1, min=0) * torch.clamp(inter_rect_y2 - inter_rect_y1 + 1, min=0)
    
        #Union Area
        b1_area = (b1_x2 - b1_x1 + 1)*(b1_y2 - b1_y1 + 1)
        b2_area = (b2_x2 - b2_x1 + 1)*(b2_y2 - b2_y1 + 1)
        
        iou = inter_area / (b1_area + b2_area - inter_area)
        
        return iou
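
    A quick sanity check with made-up boxes (note the +1 pixel convention used above):

    box1 = torch.tensor([[0., 0., 9., 9.]])       # 10x10 box, area 100
    box2 = torch.tensor([[5., 5., 14., 14.]])     # 10x10 box overlapping box1 in a 5x5 corner
    print(bbox_iou(box1, box2))                   # tensor([0.1429]) = 25 / (100 + 100 - 25)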
    
    

    For more background on NMS, see https://blog.csdn.net/shuzfan/article/details/52711706

    Tensor index (masking) operations are used as follows:

    image_pred_ = torch.Tensor([[1,2,3,4,9],[5,6,7,8,9]])
    #print(image_pred_[:,-1] == 9)
    has_9 = (image_pred_[:,-1] == 9)
    print(has_9)
    
    ###evaluation order: (image_pred_[:,-1] == 9).float().unsqueeze(1) is computed first, then the tensor multiplication
    cls_mask = image_pred_*(image_pred_[:,-1] == 9).float().unsqueeze(1)
    print(cls_mask)
    class_mask_ind = torch.nonzero(cls_mask[:,-2]).squeeze()
    image_pred_class = image_pred_[class_mask_ind]
    
    Output:
    tensor([1, 1], dtype=torch.uint8)
    tensor([[1., 2., 3., 4., 9.],
            [5., 6., 7., 8., 9.]])
    

    torch.sort usage:

    d=torch.Tensor([[1,2,3],[6,5,4]])
    e=d[:,2]
    print(e)
    print(torch.sort(e))
    
    Output
    tensor([3., 4.])
    
    torch.return_types.sort(
    values=tensor([3., 4.]),
    indices=tensor([0, 1]))
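
    The NMS code above calls torch.sort with descending=True and keeps only the returned indices ([1]) to reorder the rows:

    scores = torch.tensor([0.2, 0.9, 0.5])
    order = torch.sort(scores, descending=True)[1]
    print(order)    # tensor([1, 2, 0])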
    

    To summarize the NMS procedure:
    each image yields N detections, each with 4+1+C values (4 coordinates, 1 obj score and C class probabilities).

    • First drop the rows with obj_score < confidence.
    • For each row keep only the class with the highest class probability as the predicted class.
    • Sort the predictions by obj_score from high to low.
    • Loop over every class present and run NMS (a standalone sketch follows this list):
      • Compare the first box with every box after it and delete those with IoU > threshold, i.e. remove all boxes similar to it.
      • Compare the next remaining box with every box after it and delete all boxes similar to that one.
      • Repeat until no similar boxes remain.
      • At this point every remaining box of the class being processed is unique.
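
    To make this concrete, here is a small self-contained sketch of single-class NMS on made-up boxes (not the tutorial's exact code; it reuses the bbox_iou defined above):

    def nms_one_class(boxes, scores, nms_conf=0.4):
        # boxes: n x 4 corner coordinates, scores: n obj scores, all boxes of one class
        order = torch.sort(scores, descending=True)[1]
        boxes = boxes[order]
        keep = []
        while boxes.size(0) > 0:
            keep.append(boxes[0])                # the highest-scoring remaining box survives
            if boxes.size(0) == 1:
                break
            ious = bbox_iou(boxes[0].unsqueeze(0), boxes[1:])
            boxes = boxes[1:][ious < nms_conf]   # drop the boxes too similar to it
        return torch.stack(keep)

    boxes = torch.tensor([[0., 0., 10., 10.], [1., 1., 11., 11.], [50., 50., 60., 60.]])
    scores = torch.tensor([0.9, 0.8, 0.7])
    print(nms_one_class(boxes, scores))          # the second box is suppressed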

    The final return value of write_results is an n*8 tensor, where the 8 columns are (batch_index, 4 coordinates, 1 obj score, 1 class prob, 1 class index).

    def write_results(prediction, confidence, num_classes, nms_conf = 0.4):
        print("prediction.shape=",prediction.shape)
    
        #zero out the rows whose obj_score < confidence
        conf_mask = (prediction[:,:,4] > confidence).float().unsqueeze(2)
        prediction = prediction*conf_mask
    
        #get the box coordinates (top-left corner x, top-left corner y, bottom-right corner x, bottom-right corner y)
        box_corner = prediction.new(prediction.shape)
        box_corner[:,:,0] = (prediction[:,:,0] - prediction[:,:,2]/2)
        box_corner[:,:,1] = (prediction[:,:,1] - prediction[:,:,3]/2)
        box_corner[:,:,2] = (prediction[:,:,0] + prediction[:,:,2]/2) 
        box_corner[:,:,3] = (prediction[:,:,1] + prediction[:,:,3]/2)
        #overwrite the first four columns of the last dimension of prediction
        prediction[:,:,:4] = box_corner[:,:,:4]
    
        batch_size = prediction.size(0)
        write = False
    
        for ind in range(batch_size):
            #image_pred.shape=boxnum*boxattr
            image_pred = prediction[ind]          #image Tensor
            #confidence threshholding 
            #NMS
    
            ##take the largest class score in each row, together with its column index
            max_conf_score,max_conf = torch.max(image_pred[:,5:5+ num_classes], 1)
            max_conf = max_conf.float().unsqueeze(1)
            max_conf_score = max_conf_score.float().unsqueeze(1)
            seq = (image_pred[:,:5], max_conf_score, max_conf)
            image_pred = torch.cat(seq, 1) #now 7 columns: top-left x, top-left y, bottom-right x, bottom-right y, obj score, max class probability, corresponding class index
            print(image_pred.shape)
    
            non_zero_ind =  (torch.nonzero(image_pred[:,4]))
            try:
                image_pred_ = image_pred[non_zero_ind.squeeze(),:].view(-1,7)
            except:
                continue
    
            #For PyTorch 0.4 compatibility
            #Since the above code will not raise an exception when there is no detection
            #as scalars are supported in PyTorch 0.4
            if image_pred_.shape[0] == 0:
                continue 
    
            #Get the various classes detected in the image
            img_classes = unique(image_pred_[:,-1])  # -1 index holds the class index
            
            
            for cls in img_classes:
                #perform NMS
    
                #get the detections with one particular class
                #keep the rows whose predicted class is the current class (class prob != 0)
                cls_mask = image_pred_*(image_pred_[:,-1] == cls).float().unsqueeze(1)
                class_mask_ind = torch.nonzero(cls_mask[:,-2]).squeeze()
                image_pred_class = image_pred_[class_mask_ind].view(-1,7)
                
                #sort the detections such that the entry with the maximum objectness
                #confidence is at the top
                #sort by obj score from high to low
                conf_sort_index = torch.sort(image_pred_class[:,4], descending = True )[1]
                image_pred_class = image_pred_class[conf_sort_index]
                idx = image_pred_class.size(0)   #Number of detections
                
                for i in range(idx):
                    #Get the IOUs of all boxes that come after the one we are looking at 
                    #in the loop
                    try:
                        #compute the IoU between box i and every box after it
                        ious = bbox_iou(image_pred_class[i].unsqueeze(0), image_pred_class[i+1:])
                    except ValueError:
                        break
                
                    except IndexError:
                        break
                
                    #Zero out all the detections that have IoU > threshold
                    #boxes whose IoU with box i exceeds nms_conf are treated as the same object and zeroed out
                    iou_mask = (ious < nms_conf).float().unsqueeze(1)
                    image_pred_class[i+1:] *= iou_mask       
                
                    #remove the rows that have been zeroed out (IoU > nms_conf)
                    non_zero_ind = torch.nonzero(image_pred_class[:,4]).squeeze()
                    image_pred_class = image_pred_class[non_zero_ind].view(-1,7)
                    
                batch_ind = image_pred_class.new(image_pred_class.size(0), 1).fill_(ind)      #Repeat the batch_id for as many detections of the class cls in the image
                seq = batch_ind, image_pred_class
                
                if not write:
                    output = torch.cat(seq,1)  #concatenate along columns; shape num_detections*8
                    write = True
                else:
                    out = torch.cat(seq,1)
                    output = torch.cat((output,out)) #concatenate along rows; output grows to n*8
    
        try:
            return output
        except:
            return 0
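
    A hypothetical call site (the threshold values and num_classes below are just examples), given the prediction tensor from the forward pass with shape batch_size x boxnum x (5 + num_classes):

    # returns an n*8 tensor (batch_index, x1, y1, x2, y2, obj score, class prob, class index),
    # or the int 0 when nothing at all is detected
    output = write_results(prediction, confidence=0.5, num_classes=80, nms_conf=0.4)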
    
  • Original article: https://www.cnblogs.com/sdu20112013/p/11151453.html