• RefineDet: notes on some PyTorch and Python syntax


    Official link:
    https://github.com/luuuyi/RefineDet.PyTorch

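    itertools.product

    from itertools import product

    # f is the grid size; product(range(f), repeat=2) yields every (i, j) pair
    for k, f in enumerate([3, 5]):
        print("f:=====", f)
        for i, j in product(range(f), repeat=2):
            print(i, j)

    Output: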
    f:===== 3
    0 0
    0 1
    0 2
    1 0
    1 1
    1 2
    2 0
    2 1
    2 2
    f:===== 5
    0 0
    0 1
    0 2
    0 3
    0 4
    1 0
    1 1
    1 2
    1 3
    1 4
    2 0
    2 1
    2 2
    2 3
    2 4
    3 0
    3 1
    3 2
    3 3
    3 4
    4 0
    4 1
    4 2
    4 3
    4 4
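
    As a rough illustration (a sketch, not code from the repository), product(range(f), repeat=2) visits the same (i, j) pairs, in the same order, as two nested loops over range(f). The value f = 3 is chosen only for the example:

```python
from itertools import product

f = 3  # hypothetical grid size, chosen only for illustration

# Pairs produced by itertools.product
pairs_product = list(product(range(f), repeat=2))

# The same pairs written as two explicit nested loops
pairs_nested = [(i, j) for i in range(f) for j in range(f)]

print(pairs_product == pairs_nested)  # True
print(len(pairs_product))             # 9, i.e. f * f
```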
    

    Parsing VOC XML

    A test sample written from the code.
    For example, suppose the XML file contains the following (VOC format):

    <annotation>
       <folder>VOC2007</folder>
       <filename>seat_190530_623.jpg</filename>
       <source>
           <database>The VOC2007 Database</database>
           <annotation>PASCAL VOC2007</annotation>
           <image>flickr</image>
           <flickrid>329145082</flickrid>
       </source>
       <owner>
           <flickrid>hiromori2</flickrid>
           <name>Hiroyuki Mori</name>
       </owner>
       <size>
           <width>1024</width>
           <height>768</height>
           <depth>3</depth>
       </size>
       <segmented>0</segmented>
       <object>
           <name>zuoyianquandai</name>
           <pose>Unspecified</pose>
           <truncated>0</truncated>
           <difficult>0</difficult>
           <bndbox>
               <xmin>476</xmin>
               <ymin>276</ymin>
               <xmax>562</xmax>
               <ymax>372</ymax>
           </bndbox>
       </object>
       <object>
           <name>zuoyianquandai</name>
           <pose>Unspecified</pose>
           <truncated>0</truncated>
           <difficult>0</difficult>
           <bndbox>
               <xmin>440</xmin>
               <ymin>271</ymin>
               <xmax>506</xmax>
               <ymax>372</ymax>
           </bndbox>
       </object>
       <object>
           <name>zuoyianquandai</name>
           <pose>Unspecified</pose>
           <truncated>0</truncated>
           <difficult>0</difficult>
           <bndbox>
               <xmin>622</xmin>
               <ymin>616</ymin>
               <xmax>726</xmax>
               <ymax>717</ymax>
           </bndbox>
       </object>
       <object>
           <name>zuoyianquandai</name>
           <pose>Unspecified</pose>
           <truncated>0</truncated>
           <difficult>0</difficult>
           <bndbox>
               <xmin>348</xmin>
               <ymin>598</ymin>
               <xmax>456</xmax>
               <ymax>720</ymax>
           </bndbox>
       </object>
       <object>
           <name>seat</name>
           <pose>Unspecified</pose>
           <truncated>0</truncated>
           <difficult>0</difficult>
           <bndbox>
               <xmin>270</xmin>
               <ymin>15</ymin>
               <xmax>825</xmax>
               <ymax>367</ymax>
           </bndbox>
       </object>
       <object>
           <name>seat</name>
           <pose>Unspecified</pose>
           <truncated>0</truncated>
           <difficult>0</difficult>
           <bndbox>
               <xmin>72</xmin>
               <ymin>66</ymin>
               <xmax>492</xmax>
               <ymax>683</ymax>
           </bndbox>
       </object>
       <object>
           <name>seat</name>
           <pose>Unspecified</pose>
           <truncated>0</truncated>
           <difficult>0</difficult>
           <bndbox>
               <xmin>612</xmin>
               <ymin>0</ymin>
               <xmax>1024</xmax>
               <ymax>704</ymax>
           </bndbox>
       </object>
    </annotation>
    

    The code is as follows:

    import os
    import xml.etree.ElementTree as ET


    root_dir = "/data_2/project_2021/refinedet/pytorch_refinedet/data/VOCdevkit/VOC2007/Annotations/"

    list_xml = os.listdir(root_dir)
    for cnt, xml_name in enumerate(list_xml):
        print(cnt, xml_name)
        path_xml = os.path.join(root_dir, xml_name)
        target = ET.parse(path_xml).getroot()

        res = []  # one [xmin, ymin, xmax, ymax, label_name] entry per object
        for obj in target.iter('object'):
            # skip objects flagged as difficult
            difficult = int(obj.find('difficult').text) == 1
            if difficult:
                continue
            label = obj.find('name').text.lower().strip()
            bbox = obj.find('bndbox')

            pts = ['xmin', 'ymin', 'xmax', 'ymax']
            bndbox = []
            for i, pt in enumerate(pts):
                # round to the nearest integer, then subtract 1 (VOC coordinates are 1-based)
                cur_pt = int(float(bbox.find(pt).text) + 0.5) - 1
                # scale height or width
                # cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height
                bndbox.append(cur_pt)

            # label_idx = self.class_to_ind[label]
            # bndbox.append(label_idx)
            bndbox.append(label)  # the class name is stored here instead of label_idx
            res += [bndbox]  # [xmin, ymin, xmax, ymax, label_name]
    

    For the sample XML above, res ends up as follows (the ymin of 0 in the last box becomes -1 because 1 is subtracted from every coordinate):
    [[475, 275, 561, 371, 'zuoyianquandai'],
     [439, 270, 505, 371, 'zuoyianquandai'],
     [621, 615, 725, 716, 'zuoyianquandai'],
     [347, 597, 455, 719, 'zuoyianquandai'],
     [269, 14, 824, 366, 'seat'],
     [71, 65, 491, 682, 'seat'],
     [611, -1, 1023, 703, 'seat']]
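
    The commented-out class_to_ind lines suggest that the real loader stores a numeric class index rather than the class name. A minimal sketch of that variant follows. It assumes a hypothetical two-class list (CLASSES) and parses a short inline XML string instead of files on disk; neither is taken from the repository.

```python
import xml.etree.ElementTree as ET

# Hypothetical class list; the real project defines its own.
CLASSES = ("seat", "zuoyianquandai")
class_to_ind = {name: idx for idx, name in enumerate(CLASSES)}

# Hypothetical, shortened annotation; the second object is marked difficult.
xml_text = """
<annotation>
  <object>
    <name>seat</name>
    <difficult>0</difficult>
    <bndbox><xmin>612</xmin><ymin>0</ymin><xmax>1024</xmax><ymax>704</ymax></bndbox>
  </object>
  <object>
    <name>zuoyianquandai</name>
    <difficult>1</difficult>
    <bndbox><xmin>476</xmin><ymin>276</ymin><xmax>562</xmax><ymax>372</ymax></bndbox>
  </object>
</annotation>
""".strip()

res = []
for obj in ET.fromstring(xml_text).iter("object"):
    if int(obj.find("difficult").text) == 1:
        continue  # skip objects flagged as difficult
    label = obj.find("name").text.lower().strip()
    bbox = obj.find("bndbox")
    # Same rounding and subtract-1 step as in the loop above
    box = [int(float(bbox.find(pt).text) + 0.5) - 1
           for pt in ("xmin", "ymin", "xmax", "ymax")]
    res.append(box + [class_to_ind[label]])

print(res)  # [[611, -1, 1023, 703, 0]]
```

    Only the first object survives, because the second one is flagged as difficult.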

    np.hstack() and np.vstack(), e.g. target = np.hstack((boxes, np.expand_dims(labels, axis=1)))

    np.vstack(): stacks arrays vertically (row-wise)
    np.hstack(): stacks arrays horizontally (column-wise)

    import numpy as np
    
    arr1 = np.array([1, 2, 3])
    arr2 = np.array([4, 5, 6])
    print(np.vstack)
    print(np.vstack((arr1, arr2)))
    print(np.hstack)
    print(np.hstack((arr1, arr2)))
    

    Output:
    <function vstack at 0x7ff6e333d0e0>
    [[1 2 3]
     [4 5 6]]
    <function hstack at 0x7ff6e333d290>
    [1 2 3 4 5 6]

    target = np.hstack((boxes, np.expand_dims(labels, axis=1)))
    For example, with 5 objects:
    boxes has shape [5, 4]
    labels has shape [5], and np.expand_dims(labels, axis=1) turns it into [5, 1]
    so target has shape [5, 5]
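
    The shape bookkeeping can be checked with a small sketch. The box and label values are made up for illustration; only the shapes matter here.

```python
import numpy as np

# Hypothetical data: 5 boxes (xmin, ymin, xmax, ymax) and 5 class labels
boxes = np.arange(20, dtype=np.float32).reshape(5, 4)
labels = np.array([0, 1, 1, 0, 1], dtype=np.float32)

labels_col = np.expand_dims(labels, axis=1)   # shape (5,) -> (5, 1)
target = np.hstack((boxes, labels_col))       # (5, 4) and (5, 1) -> (5, 5)

print(labels.shape, labels_col.shape, target.shape)  # (5,) (5, 1) (5, 5)
print(target[0])  # [0. 1. 2. 3. 0.]
```

    Without expand_dims, np.hstack would fail here, since a 2-D array and a 1-D array cannot be concatenated along axis 1.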

    a[::-1] (returns a reversed copy of the sequence)

    a = [1,2,3,4,5]
    
    b = a[::-1]
    
    print(a)
    print(b)
    #[1, 2, 3, 4, 5]
    #[5, 4, 3, 2, 1]
    

    zip, e.g. for (x, l, c) in zip(sources, self.arm_loc, self.arm_conf):

    sources, self.arm_loc and self.arm_conf are lists of the same length. sources holds the data, while arm_loc and arm_conf hold operations such as conv2d layers.

    for (x, l, c) in zip(sources, self.arm_loc, self.arm_conf):
        arm_loc.append(l(x).permute(0, 2, 3, 1).contiguous())
        arm_conf.append(c(x).permute(0, 2, 3, 1).contiguous())
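
    A self-contained sketch of the same pattern, using stand-in tensors and layers (the channel counts and sizes below are hypothetical, not the ones RefineDet uses):

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins: two feature maps and one conv layer per map
sources = [torch.rand(1, 8, 4, 4), torch.rand(1, 16, 2, 2)]
loc_layers = nn.ModuleList([
    nn.Conv2d(8, 12, kernel_size=3, padding=1),
    nn.Conv2d(16, 12, kernel_size=3, padding=1),
])

arm_loc = []
# zip pairs the i-th feature map with the i-th layer
for x, l in zip(sources, loc_layers):
    # (N, C, H, W) -> (N, H, W, C), made contiguous so view() can be used later
    arm_loc.append(l(x).permute(0, 2, 3, 1).contiguous())

print([tuple(t.shape) for t in arm_loc])  # [(1, 4, 4, 12), (1, 2, 2, 12)]
```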
    

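    torch.max(), e.g. best_truth_overlap, best_truth_idx = overlap.max(0, keepdim=True), with indices such as tensor([[6, 3, 0, ..., 6, 0, 2]])

    torch.max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)
    Returns the maximum values along dimension dim, together with their indices.
    torch.max(a, 0) returns the largest element of each column and its index (the row index of that element within the column). The maximum values and the indices are each a tensor, and together they form the tuple (Tensor, LongTensor).
    torch.max(a, 1) returns the largest element of each row and its index (the column index of that element within the row).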

    import torch
    a = torch.rand(3,5)
    print(a)
    print("========================")
    print("a.max(0)")
    print(a.max(0))
    print("========================")
    print("a.max(1)")
    print(a.max(1))
    
    tensor([[0.2695, 0.3127, 0.5122, 0.4659, 0.8935],
            [0.8419, 0.1534, 0.4232, 0.7792, 0.4795],
            [0.9919, 0.9686, 0.1972, 0.2406, 0.4112]])
    ========================
    a.max(0)
    torch.return_types.max(
    values=tensor([0.9919, 0.9686, 0.5122, 0.7792, 0.8935]),
    indices=tensor([2, 2, 0, 1, 0]))
    ========================
    a.max(1)
    torch.return_types.max(
    values=tensor([0.8935, 0.8419, 0.9919]),
    indices=tensor([4, 0, 0]))
    

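    This confused me for a while: I could not keep max(0) and max(1) apart. Does 0 mean columns and 1 mean rows?
    The rule is that dim names the dimension that gets reduced. A tensor of shape [3, 5] becomes [5] after max(0), or [1, 5] with keepdim=True.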
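
    A quick way to check which dimension disappears is to print the shapes. This is only a sketch on a random tensor:

```python
import torch

a = torch.rand(3, 5)

values, indices = a.max(0)                       # dim 0 (size 3) is reduced away
values_kd, indices_kd = a.max(0, keepdim=True)   # dim 0 is kept with size 1

print(values.shape)        # torch.Size([5])
print(values_kd.shape)     # torch.Size([1, 5])
print(a.max(1)[0].shape)   # torch.Size([3])
```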
    In RefineDet, the code looks like this:

    overlap = torch.rand(7,6375)
    best_prior_overlap, best_prior_idx = overlap.max(1, keepdim=True)
    best_truth_overlap, best_truth_idx = overlap.max(0, keepdim=True)
    

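    overlap holds the IoU between 7 ground-truth boxes and 6375 priors. So what shape does best_prior_overlap have, and what does it mean?
    best_prior_overlap has shape [7, 1].
    best_prior_idx has shape [7, 1], with values in [0, 6375).
    Together they say, for each ground-truth box, which prior has the largest IoU with it and what that IoU is.

    best_truth_overlap has shape [1, 6375].
    best_truth_idx has shape [1, 6375], with values in [0, 7).
    Together they say, for each prior, which ground-truth box has the largest IoU with it and what that IoU is.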
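
    The same two calls on a tiny hand-written IoU matrix make the meaning easier to see. The matrix below is hypothetical (2 ground-truth boxes, 4 priors) and the values are chosen without ties:

```python
import torch

# Hypothetical IoU matrix: 2 ground-truth boxes (rows) x 4 priors (columns)
overlap = torch.tensor([[0.1, 0.7, 0.3, 0.2],
                        [0.6, 0.2, 0.4, 0.1]])

# For each ground-truth box (row): its best prior
best_prior_overlap, best_prior_idx = overlap.max(1, keepdim=True)
# For each prior (column): its best ground-truth box
best_truth_overlap, best_truth_idx = overlap.max(0, keepdim=True)

print(best_prior_idx.shape, best_prior_idx.flatten().tolist())  # torch.Size([2, 1]) [1, 0]
print(best_truth_idx.shape, best_truth_idx.flatten().tolist())  # torch.Size([1, 4]) [1, 0, 1, 0]
```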

    index_fill_(dim, index, val), e.g. best_truth_overlap.index_fill_(0, best_prior_idx, 2)  # ensure best prior

    x = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    index = torch.LongTensor([0, 2])
    x.index_fill_(1, index, 8)#([[8., 2., 8.],
                               # [8., 5., 8.],
                               # [8., 8., 8.]])
    

    In the RefineDet code:

    best_truth_overlap.index_fill_(0, best_prior_idx, 2)  # ensure best prior
    aaa = best_truth_overlap[best_prior_idx[0].type(torch.LongTensor)] ##==2?   yes!
    

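    This is the interesting part. best_truth_overlap initially holds IoU values between 0 and 1. Here it has shape [6375] and comes from taking the maximum down each column, i.e. for each prior, its largest IoU with any ground-truth box.
    best_prior_idx holds values in [0, 6375) and comes from taking the maximum along each row, i.e. for each ground-truth box, the position of its best prior. Straight out of max(1, keepdim=True) its shape is [7, 1]; index_fill_ expects a 1-D index, so it must have been squeezed to [7] before this call.
    best_truth_overlap.index_fill_(0, best_prior_idx, 2) writes the value 2 into best_truth_overlap at the positions listed in best_prior_idx. It behaves much like best_truth_overlap[best_prior_idx] = 2. Since a real IoU never exceeds 1, the value 2 presumably makes sure these priors pass any IoU threshold applied later.
    In short, it does what the code comment says: # ensure best prior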
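
    A small sketch of this step on a hypothetical 2 x 4 IoU matrix. It calls max without keepdim so that both tensors are already 1-D, which is an assumption of this example rather than the way the repository arranges it:

```python
import torch

# Hypothetical IoU matrix: 2 ground-truth boxes (rows) x 4 priors (columns)
overlap = torch.tensor([[0.1, 0.7, 0.3, 0.2],
                        [0.6, 0.2, 0.4, 0.1]])

best_prior_overlap, best_prior_idx = overlap.max(1)   # both have shape [2]
best_truth_overlap, best_truth_idx = overlap.max(0)   # both have shape [4]

# Mark the best prior of every ground-truth box with 2, a value no real IoU can reach
best_truth_overlap.index_fill_(0, best_prior_idx, 2)

# Priors 1 and 0 are the best priors, so entries 0 and 1 now hold 2;
# entries 2 and 3 keep their original IoU values
print(best_truth_overlap)
```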

    A good memory is no match for a worn-out keyboard: bit by bit, accumulate, improve!
  • Original article: https://www.cnblogs.com/yanghailin/p/14378830.html