• refinedet tensorRT实现



    本文github链接,欢迎star

    1.准备工作

    首先,要弄清楚自己在做什么,然后一步步的去实现它,在此过程中,要不断验证自己的每个步骤是否正确。验证正确了再往下继续走,不正确就要去排查哪里有问题。
    现在是需要把pytorch refinedet转tensorrt。而且是一步步的搭建网络实现。
    pytorch refinedet https://github.com/luuuyi/RefineDet.PyTorch
    tensorrt https://github.com/wang-xinyu/tensorrtx
    我也是刚接触tensorrt,做一件事情之前找个现成例子,跑一跑,再学习实现细节。
    我跑通了tensorrtx里面yolov3的demo,这期间花费了很长时间。因为tensorrtx实现的是U版的yolov3,https://github.com/ultralytics/yolov3。
    这个仓库代码是商用的,实现的质量很高,并且代码里面用了各种奇淫技巧。我首先是把这个纯pytorch的yoloV3代码跑通,并结合论文弄懂了原理。然后再去看tensorrt的复现。
    tensorrtx这个仓库里面包含了很多流行网络的实现,都是用tensorrt api一步步搭的。
    感谢作者开源了这么优秀的代码,并且一群优秀的人不断贡献着开源。
    总结一个大体流程:
    (1)要了解网络的每层的实现,数据流流向,nchw大小,卷积核大小。
    (2)在pytorch工程里面把权重导出,以16进制写入wts文件,格式类似字典:层名对应权重。后续tensorRT里面根据层名来索引对应权重。
    (3)搭一层网络加载对应权重
    (4)先生成engine,再推理。
    具体细节后面讲解
    给出一些链接:
    安装tensorrt教程:
    https://github.com/wang-xinyu/tensorrtx/blob/master/tutorials/install.md
    一步步搭建tensorrt详细的教程:
    https://github.com/wang-xinyu/tensorrtx/blob/master/tutorials/from_pytorch_to_trt_stepbystep_hrnet.md

    2.wts权重文件生成

    import torch
    import torch.nn as nn
    import struct
    from models.refinedet import build_refinedet
    
    # Export the RefineDet state_dict to a text ".wts" file for the TensorRT side.
    # Layout: first line = number of tensors; then one line per tensor:
    #   "<name> <count>  <hex> <hex> ..."
    # where every <hex> is the big-endian float32 bit pattern of one value
    # (struct.pack('>f', ...)).
    num_classes = 25
    path_model = "/data_2/pytorch_refinedet/2021/20210308.pth"
    path_save_wts = "./refinedet0312.wts"
    input_size = 320

    net = build_refinedet('test', input_size, num_classes)  # initialize net
    # map_location='cpu' lets a checkpoint saved on a GPU load on any machine.
    net.load_state_dict(torch.load(path_model, map_location='cpu'))
    net.eval()

    state_dict = net.state_dict()
    # `with` guarantees the file is flushed and closed even if an error occurs.
    with open(path_save_wts, 'w') as f:
        f.write('{}\n'.format(len(state_dict)))
        for k, v in state_dict.items():
            vr = v.reshape(-1).cpu().numpy()
            f.write('{} {} '.format(k, len(vr)))
            # Each value is written as " <hex>"; the C++ reader splits on whitespace.
            f.write(''.join(' ' + struct.pack('>f', float(vv)).hex() for vv in vr))
            f.write('\n')
    print("success generate wts!")
    

    这个基本是用的仓库的代码。加载pth文件,然后保存字典权重wts到本地。
    在tensorrt里面解析该文件,解析成类似于字典的格式(std::map),层名对应权重:

    // Parse a ".wts" text file into a map from layer name to TensorRT Weights.
    // Layout (as written by the export script): the first token is the number of
    // blobs; each blob is "<name> <count> <hex32> <hex32> ...", where every hex32
    // is the raw bit pattern of one float32 value.
    // Each Weights::values buffer is allocated with malloc(); the caller owns it
    // and must release it with free() once the engine has been built.
    // NOTE: the checks use assert(), so they are compiled out under NDEBUG.
    std::map<std::string, Weights> loadWeights(const std::string file) {
        std::cout << "Loading weights: " << file << std::endl;
        std::map<std::string, Weights> weightMap;

        // Open weights file
        std::ifstream input(file);
        assert(input.is_open() && "Unable to load weight file.");

        // Read number of weight blobs
        int32_t count = 0;
        input >> count;
        assert(count > 0 && "Invalid weight map file.");

        while (count--)
        {
            Weights wt{DataType::kFLOAT, nullptr, 0};
            uint32_t size = 0;

            // Read the blob name and its element count (the count is decimal).
            std::string name;
            input >> name >> std::dec >> size;
            assert(input && "Malformed blob header in weight file.");

            // One uint32_t per element: sizeof(*val), not sizeof(val) -- the
            // latter is the size of the pointer and over-allocates on 64-bit.
            uint32_t* val = reinterpret_cast<uint32_t*>(malloc(sizeof(*val) * size));
            assert((val != nullptr || size == 0) && "Out of memory loading weights.");
            for (uint32_t x = 0; x < size; ++x)
            {
                // Values are hex-encoded float bit patterns.
                input >> std::hex >> val[x];
            }
            assert(input && "Truncated blob data in weight file.");

            wt.values = val;
            wt.count = size;
            weightMap[name] = wt;
        }

        return weightMap;
    }
    

    3.tensorrt 网络搭建

    原生的pytorch工程中,网络的构建方式比较复杂,如下所示:

    def build_refinedet(phase, size=320, num_classes=21):
        """Assemble a RefineDet network from the module-level config tables.

        Args:
            phase: "train" or "test"; anything else prints an error and the
                function returns None.
            size: network input size; only 320 and 512 are accepted, any other
                value prints an error and the function returns None.
            num_classes: forwarded to the ODM head builder and to RefineDet.

        Returns:
            A RefineDet instance, or None when an argument is rejected.
        """
        if phase not in ("test", "train"):
            print("ERROR: Phase: " + phase + " not recognized")
            return
        if size not in (320, 512):
            print("ERROR: You specified size " + repr(size) + ". However, " +
                  "currently only RefineDet320 and RefineDet512 is supported!")
            return

        cfg_key = str(size)
        vgg_layers = vgg(base[cfg_key], 3)
        extra_layers = add_extras(extras[cfg_key], size, 1024)
        anchor_cfg = mbox[cfg_key]
        arm_heads = arm_multibox(vgg_layers, extra_layers, anchor_cfg)
        odm_heads = odm_multibox(vgg_layers, extra_layers, anchor_cfg, num_classes)
        tcb_layers = add_tcb(tcb[cfg_key])
        return RefineDet(phase, size, vgg_layers, extra_layers,
                         arm_heads, odm_heads, tcb_layers, num_classes)
    

    网络由各个组件来完成,每个组件里面各种循环添加卷积、池化层等。这种方式不利于我们一层层的搭建和调试代码。所以,我就用最笨的方法,就是一层层的写卷积、池化同样实现作者的这个网络。
    当然,也是需要把refinedet这个工程看的熟透。所以,我自己按照官方实现的搭的网络就是像下面这样的:

    class refinedet_my(nn.Module):
        """Hand-unrolled RefineDet used for layer-by-layer comparison with TensorRT.

        Every layer is its own named attribute so intermediate outputs can be
        inspected in ``forward``. Attribute names follow the original checkpoint
        keys with the container index folded into the name (``vgg.N`` ->
        ``convN``, ``extras.N`` -> ``extrasN``, ``arm_loc.N`` -> ``arm_loc_N``,
        ``tcb0.N`` -> ``tcb0_N``, ...).

        Args:
            num_classes: number of ODM classes; each ODM conf head outputs
                ``3 * num_classes`` channels. Defaults to 25.
        """

        def __init__(self, num_classes=25):
            super().__init__()
            self.num_classes = num_classes

            # ---- VGG16 backbone (checkpoint keys vgg.N) ----
            self.conv0 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
            self.relu1 = nn.ReLU(inplace=True)
            self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
            self.relu3 = nn.ReLU(inplace=True)
            self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

            self.conv5 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
            self.relu6 = nn.ReLU(inplace=True)
            self.conv7 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
            self.relu8 = nn.ReLU(inplace=True)
            self.maxpool9 = nn.MaxPool2d(kernel_size=2, stride=2)

            self.conv10 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
            self.relu11 = nn.ReLU(inplace=True)
            self.conv12 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
            self.relu13 = nn.ReLU(inplace=True)
            self.conv14 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
            self.relu15 = nn.ReLU(inplace=True)
            # ceil_mode=True: odd-sized maps are rounded up instead of down.
            self.maxpool16 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

            self.conv17 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
            self.relu18 = nn.ReLU(inplace=True)
            self.conv19 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
            self.relu20 = nn.ReLU(inplace=True)
            self.conv21 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
            self.relu22 = nn.ReLU(inplace=True)
            self.maxpool23 = nn.MaxPool2d(kernel_size=2, stride=2)

            self.conv24 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
            self.relu25 = nn.ReLU(inplace=True)
            self.conv26 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
            self.relu27 = nn.ReLU(inplace=True)
            self.conv28 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
            self.relu29 = nn.ReLU(inplace=True)
            self.maxpool30 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

            # Dilated 3x3 conv; dilation=3 with padding=3 keeps the spatial size.
            self.conv31 = nn.Conv2d(512, 1024, kernel_size=3, padding=3, dilation=3)
            self.relu32 = nn.ReLU(inplace=True)

            self.conv33 = nn.Conv2d(1024, 1024, kernel_size=1)
            self.relu34 = nn.ReLU(inplace=True)

            # ---- extra layers (checkpoint keys extras.N) ----
            self.extras0 = nn.Conv2d(1024, 256, kernel_size=1)
            self.relu_e0 = nn.ReLU(inplace=True)
            self.extras1 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
            self.relu_e1 = nn.ReLU(inplace=True)

            # ---- L2Norm on conv4_3 / conv5_3 (L2Norm is defined elsewhere) ----
            self.conv4_3_L2Norm = L2Norm(512, 10)
            self.conv5_3_L2Norm = L2Norm(512, 8)

            # ---- ARM heads (checkpoint keys arm_loc.N / arm_conf.N) ----
            self.arm_loc_0 = nn.Conv2d(512, 12, kernel_size=3, padding=1)
            self.arm_loc_1 = nn.Conv2d(512, 12, kernel_size=3, padding=1)
            self.arm_loc_2 = nn.Conv2d(1024, 12, kernel_size=3, padding=1)
            self.arm_loc_3 = nn.Conv2d(512, 12, kernel_size=3, padding=1)

            self.arm_conf_0 = nn.Conv2d(512, 6, kernel_size=3, padding=1)
            self.arm_conf_1 = nn.Conv2d(512, 6, kernel_size=3, padding=1)
            self.arm_conf_2 = nn.Conv2d(1024, 6, kernel_size=3, padding=1)
            self.arm_conf_3 = nn.Conv2d(512, 6, kernel_size=3, padding=1)

            # ---- TCB part 0 (checkpoint keys tcb0.N) ----
            self.tcb0_9 = nn.Conv2d(512, 256, 3, padding=1)
            self.tcb0_10 = nn.ReLU(inplace=True)
            self.tcb0_11 = nn.Conv2d(256, 256, 3, padding=1)

            self.tcb0_6 = nn.Conv2d(1024, 256, 3, padding=1)
            self.tcb0_7 = nn.ReLU(inplace=True)
            self.tcb0_8 = nn.Conv2d(256, 256, 3, padding=1)

            self.tcb0_3 = nn.Conv2d(512, 256, 3, padding=1)
            self.tcb0_4 = nn.ReLU(inplace=True)
            self.tcb0_5 = nn.Conv2d(256, 256, 3, padding=1)

            self.tcb0_0 = nn.Conv2d(512, 256, 3, padding=1)
            self.tcb0_1 = nn.ReLU(inplace=True)
            self.tcb0_2 = nn.Conv2d(256, 256, 3, padding=1)

            # ---- TCB part 2 (checkpoint keys tcb2.N) ----
            self.tcb2_0 = nn.ReLU(inplace=True)
            self.tcb2_1 = nn.Conv2d(256, 256, 3, padding=1)
            self.tcb2_2 = nn.ReLU(inplace=True)

            self.tcb2_3 = nn.ReLU(inplace=True)
            self.tcb2_4 = nn.Conv2d(256, 256, 3, padding=1)
            self.tcb2_5 = nn.ReLU(inplace=True)

            self.tcb2_6 = nn.ReLU(inplace=True)
            self.tcb2_7 = nn.Conv2d(256, 256, 3, padding=1)
            self.tcb2_8 = nn.ReLU(inplace=True)

            self.tcb2_9 = nn.ReLU(inplace=True)
            self.tcb2_10 = nn.Conv2d(256, 256, 3, padding=1)
            self.tcb2_11 = nn.ReLU(inplace=True)

            # ---- TCB part 1 (checkpoint keys tcb1.N): 2x2/stride-2 deconvs, x2 upsample ----
            self.tcb1_2 = nn.ConvTranspose2d(256, 256, 2, 2)
            self.tcb1_1 = nn.ConvTranspose2d(256, 256, 2, 2)
            self.tcb1_0 = nn.ConvTranspose2d(256, 256, 2, 2)

            # ---- ODM heads (checkpoint keys odm_loc.N / odm_conf.N) ----
            self.odm_loc_0 = nn.Conv2d(256, 12, kernel_size=3, padding=1)
            self.odm_loc_1 = nn.Conv2d(256, 12, kernel_size=3, padding=1)
            self.odm_loc_2 = nn.Conv2d(256, 12, kernel_size=3, padding=1)
            self.odm_loc_3 = nn.Conv2d(256, 12, kernel_size=3, padding=1)

            self.odm_conf_0 = nn.Conv2d(256, 3 * self.num_classes, kernel_size=3, padding=1)
            self.odm_conf_1 = nn.Conv2d(256, 3 * self.num_classes, kernel_size=3, padding=1)
            self.odm_conf_2 = nn.Conv2d(256, 3 * self.num_classes, kernel_size=3, padding=1)
            self.odm_conf_3 = nn.Conv2d(256, 3 * self.num_classes, kernel_size=3, padding=1)

            self.softmax = nn.Softmax(dim=-1)

        def forward(self, inputs):
            """Run the part of the network that has been ported so far.

            Only the VGG backbone up to ``conv33``/``relu34`` is executed. The
            L2Norm-ed conv4_3 / conv5_3 maps and the ``relu34`` output are
            collected in ``sources`` but are not yet consumed by any head.

            Args:
                inputs: image batch tensor; it is moved to the device of the
                    model's parameters.

            Returns:
                The ``relu34`` activation.
            """
            # Follow the model's device instead of hard-coding .cuda(), so this
            # works for CPU debugging as well as on any GPU index.
            inputs = inputs.to(self.conv0.weight.device)
            sources = list()

            x = self.relu1(self.conv0(inputs))
            x = self.relu3(self.conv2(x))
            x = self.maxpool4(x)

            x = self.relu6(self.conv5(x))
            x = self.relu8(self.conv7(x))
            x = self.maxpool9(x)

            x = self.relu11(self.conv10(x))
            x = self.relu13(self.conv12(x))
            x = self.relu15(self.conv14(x))
            x = self.maxpool16(x)

            x = self.relu18(self.conv17(x))
            x = self.relu20(self.conv19(x))
            x = self.relu22(self.conv21(x))
            # clone() so the L2Norm branch can never modify x in place.
            out_conv4_3_L2Norm = self.conv4_3_L2Norm(x.clone())  # source 0
            sources.append(out_conv4_3_L2Norm)
            x = self.maxpool23(x)

            x = self.relu25(self.conv24(x))
            x = self.relu27(self.conv26(x))
            x = self.relu29(self.conv28(x))
            out_conv5_3_L2Norm = self.conv5_3_L2Norm(x.clone())  # source 1
            sources.append(out_conv5_3_L2Norm)
            x = self.maxpool30(x)

            x = self.relu32(self.conv31(x))
            x = self.relu34(self.conv33(x))
            sources.append(x)
            return x
    

    有一层搭一层,这样还可以随时调试,在forward函数里面想看哪一层输出就可以直接打断点看tensor值。

    当然一开始不会一下子写这么多的,一开始都是pytorch里面写一层,然后tensorrt里面搭一层,然后对比两边结果一不一样,一样再继续往下走。
    这里有几个问题:
    1.需要保证两边送到网络的input一致才能验证网络输出是否一致
    2.需要保证两边加载的权重一致,才能验证网络输出是否一致
    对于第一个问题,保证两边送到网络的input一致
    pytorch读一张图代码,这张图我本地resize到320大小

        # Preprocess the (already resized) image the same way as the TensorRT
        # side: BGR -> RGB, scale to [0, 1], HWC -> CHW, add a batch dimension.
        img = cv2.imread(path_img)
        if img is None:  # cv2.imread returns None instead of raising
            raise FileNotFoundError(path_img)
        img = img.astype(np.float32)
        img = img[:, :, (2, 1, 0)]  # bgr --> rgb
        img = img / 255.0
        img_2 = torch.from_numpy(img).permute(2, 0, 1)  # hwc --> chw
        out = net(img_2.unsqueeze(0))  # add batch dim -> (1, C, H, W)
    

    然后对应的tensorrt里面读图的代码

            float data[3 * INPUT_H * INPUT_W];
            pr_img = cv::imread(path_img);
            for (int i = 0; i < INPUT_H * INPUT_W; i++) {
                data[i] = (float)(pr_img.at<cv::Vec3b>(i)[2]) * 1.0 / 255.0;
                data[i + INPUT_H * INPUT_W] = (float)(pr_img.at<cv::Vec3b>(i)[1]) * 1.0 / 255.0;
                data[i + 2 * INPUT_H * INPUT_W] = (float)(pr_img.at<cv::Vec3b>(i)[0]) * 1.0  / 255.0;
            }
    

    这两者是等价的:opencv读取的图片是BGR格式,所以(i)[2]对应R通道、(i)[0]对应B通道;两边都除以255归一化,并按CHW排列。
    对于第二个问题,保证两边加载的权重是一致
    pytorch加载模型权重接口:
    checkpoint = torch.load(path_model, map_location=torch.device('cpu'))
    net.load_state_dict(checkpoint, strict=False)
    这里面的参数strict,设置为True的时候,要求pth中的层名和网络的每一层完全一致,不一致就会报错退出;
    设置为False的时候,只加载名字一致的层,名字不一致的层会被直接跳过且不会报错(所以要自己检查是否所有层都加载上了)。
    checkpoint = torch.load(path_model, map_location=torch.device('cpu'))
    这个pth通过torch.load之后就是相当于一个字典的格式,名字对应权重。
    通过下面的代码,可以打印出网络需要的层名和权重大小。同时也打印出load的模型的层名和shape大小。

        # List the parameter names/shapes the hand-written net expects, then the
        # names/shapes stored in the checkpoint, so the two can be matched up.
        net = refinedet_my()
        print("=" * 50)
        for index, (name, param) in enumerate(net.named_parameters()):
            print(str(index) + ':', name, param.size())
        print("=" * 50)

        for k, v in net.state_dict().items():
            print(k, "  shape::", v.shape)

        print("@" * 50)

        checkpoint = torch.load(path_model, map_location=torch.device('cpu'))

        print("--------load pth name----------------------")
        for k, v in checkpoint.items():
            print(k, "  shape==", v.shape)
    

    打印如下:

    ==================================================
    0: conv0.weight torch.Size([64, 3, 3, 3])
    1: conv0.bias torch.Size([64])
    2: conv2.weight torch.Size([64, 64, 3, 3])
    3: conv2.bias torch.Size([64])
    4: conv5.weight torch.Size([128, 64, 3, 3])
    5: conv5.bias torch.Size([128])
    6: conv7.weight torch.Size([128, 128, 3, 3])
    7: conv7.bias torch.Size([128])
    8: conv10.weight torch.Size([256, 128, 3, 3])
    9: conv10.bias torch.Size([256])
    10: conv12.weight torch.Size([256, 256, 3, 3])
    11: conv12.bias torch.Size([256])
    12: conv14.weight torch.Size([256, 256, 3, 3])
    13: conv14.bias torch.Size([256])
    14: conv17.weight torch.Size([512, 256, 3, 3])
    15: conv17.bias torch.Size([512])
    16: conv19.weight torch.Size([512, 512, 3, 3])
    17: conv19.bias torch.Size([512])
    18: conv21.weight torch.Size([512, 512, 3, 3])
    19: conv21.bias torch.Size([512])
    20: conv24.weight torch.Size([512, 512, 3, 3])
    21: conv24.bias torch.Size([512])
    22: conv26.weight torch.Size([512, 512, 3, 3])
    23: conv26.bias torch.Size([512])
    24: conv28.weight torch.Size([512, 512, 3, 3])
    25: conv28.bias torch.Size([512])
    26: conv31.weight torch.Size([1024, 512, 3, 3])
    27: conv31.bias torch.Size([1024])
    28: conv33.weight torch.Size([1024, 1024, 1, 1])
    29: conv33.bias torch.Size([1024])
    30: extras0.weight torch.Size([256, 1024, 1, 1])
    31: extras0.bias torch.Size([256])
    32: extras1.weight torch.Size([512, 256, 3, 3])
    33: extras1.bias torch.Size([512])
    34: conv4_3_L2Norm.weight torch.Size([512])
    35: conv5_3_L2Norm.weight torch.Size([512])
    36: arm_loc_0.weight torch.Size([12, 512, 3, 3])
    37: arm_loc_0.bias torch.Size([12])
    38: arm_loc_1.weight torch.Size([12, 512, 3, 3])
    39: arm_loc_1.bias torch.Size([12])
    40: arm_loc_2.weight torch.Size([12, 1024, 3, 3])
    41: arm_loc_2.bias torch.Size([12])
    42: arm_loc_3.weight torch.Size([12, 512, 3, 3])
    43: arm_loc_3.bias torch.Size([12])
    44: arm_conf_0.weight torch.Size([6, 512, 3, 3])
    45: arm_conf_0.bias torch.Size([6])
    46: arm_conf_1.weight torch.Size([6, 512, 3, 3])
    47: arm_conf_1.bias torch.Size([6])
    48: arm_conf_2.weight torch.Size([6, 1024, 3, 3])
    49: arm_conf_2.bias torch.Size([6])
    50: arm_conf_3.weight torch.Size([6, 512, 3, 3])
    51: arm_conf_3.bias torch.Size([6])
    52: tcb0_9.weight torch.Size([256, 512, 3, 3])
    53: tcb0_9.bias torch.Size([256])
    54: tcb0_11.weight torch.Size([256, 256, 3, 3])
    55: tcb0_11.bias torch.Size([256])
    56: tcb0_6.weight torch.Size([256, 1024, 3, 3])
    57: tcb0_6.bias torch.Size([256])
    58: tcb0_8.weight torch.Size([256, 256, 3, 3])
    59: tcb0_8.bias torch.Size([256])
    60: tcb0_3.weight torch.Size([256, 512, 3, 3])
    61: tcb0_3.bias torch.Size([256])
    62: tcb0_5.weight torch.Size([256, 256, 3, 3])
    63: tcb0_5.bias torch.Size([256])
    64: tcb0_0.weight torch.Size([256, 512, 3, 3])
    65: tcb0_0.bias torch.Size([256])
    66: tcb0_2.weight torch.Size([256, 256, 3, 3])
    67: tcb0_2.bias torch.Size([256])
    68: tcb2_1.weight torch.Size([256, 256, 3, 3])
    69: tcb2_1.bias torch.Size([256])
    70: tcb2_4.weight torch.Size([256, 256, 3, 3])
    71: tcb2_4.bias torch.Size([256])
    72: tcb2_7.weight torch.Size([256, 256, 3, 3])
    73: tcb2_7.bias torch.Size([256])
    74: tcb2_10.weight torch.Size([256, 256, 3, 3])
    75: tcb2_10.bias torch.Size([256])
    76: tcb1_2.weight torch.Size([256, 256, 2, 2])
    77: tcb1_2.bias torch.Size([256])
    78: tcb1_1.weight torch.Size([256, 256, 2, 2])
    79: tcb1_1.bias torch.Size([256])
    80: tcb1_0.weight torch.Size([256, 256, 2, 2])
    81: tcb1_0.bias torch.Size([256])
    82: odm_loc_0.weight torch.Size([12, 256, 3, 3])
    83: odm_loc_0.bias torch.Size([12])
    84: odm_loc_1.weight torch.Size([12, 256, 3, 3])
    85: odm_loc_1.bias torch.Size([12])
    86: odm_loc_2.weight torch.Size([12, 256, 3, 3])
    87: odm_loc_2.bias torch.Size([12])
    88: odm_loc_3.weight torch.Size([12, 256, 3, 3])
    89: odm_loc_3.bias torch.Size([12])
    90: odm_conf_0.weight torch.Size([75, 256, 3, 3])
    91: odm_conf_0.bias torch.Size([75])
    92: odm_conf_1.weight torch.Size([75, 256, 3, 3])
    93: odm_conf_1.bias torch.Size([75])
    94: odm_conf_2.weight torch.Size([75, 256, 3, 3])
    95: odm_conf_2.bias torch.Size([75])
    96: odm_conf_3.weight torch.Size([75, 256, 3, 3])
    97: odm_conf_3.bias torch.Size([75])
    ==================================================
    conv0.weight   shape:: torch.Size([64, 3, 3, 3])
    conv0.bias   shape:: torch.Size([64])
    conv2.weight   shape:: torch.Size([64, 64, 3, 3])
    conv2.bias   shape:: torch.Size([64])
    conv5.weight   shape:: torch.Size([128, 64, 3, 3])
    conv5.bias   shape:: torch.Size([128])
    conv7.weight   shape:: torch.Size([128, 128, 3, 3])
    conv7.bias   shape:: torch.Size([128])
    conv10.weight   shape:: torch.Size([256, 128, 3, 3])
    conv10.bias   shape:: torch.Size([256])
    conv12.weight   shape:: torch.Size([256, 256, 3, 3])
    conv12.bias   shape:: torch.Size([256])
    conv14.weight   shape:: torch.Size([256, 256, 3, 3])
    conv14.bias   shape:: torch.Size([256])
    conv17.weight   shape:: torch.Size([512, 256, 3, 3])
    conv17.bias   shape:: torch.Size([512])
    conv19.weight   shape:: torch.Size([512, 512, 3, 3])
    conv19.bias   shape:: torch.Size([512])
    conv21.weight   shape:: torch.Size([512, 512, 3, 3])
    conv21.bias   shape:: torch.Size([512])
    conv24.weight   shape:: torch.Size([512, 512, 3, 3])
    conv24.bias   shape:: torch.Size([512])
    conv26.weight   shape:: torch.Size([512, 512, 3, 3])
    conv26.bias   shape:: torch.Size([512])
    conv28.weight   shape:: torch.Size([512, 512, 3, 3])
    conv28.bias   shape:: torch.Size([512])
    conv31.weight   shape:: torch.Size([1024, 512, 3, 3])
    conv31.bias   shape:: torch.Size([1024])
    conv33.weight   shape:: torch.Size([1024, 1024, 1, 1])
    conv33.bias   shape:: torch.Size([1024])
    extras0.weight   shape:: torch.Size([256, 1024, 1, 1])
    extras0.bias   shape:: torch.Size([256])
    extras1.weight   shape:: torch.Size([512, 256, 3, 3])
    extras1.bias   shape:: torch.Size([512])
    conv4_3_L2Norm.weight   shape:: torch.Size([512])
    conv5_3_L2Norm.weight   shape:: torch.Size([512])
    arm_loc_0.weight   shape:: torch.Size([12, 512, 3, 3])
    arm_loc_0.bias   shape:: torch.Size([12])
    arm_loc_1.weight   shape:: torch.Size([12, 512, 3, 3])
    arm_loc_1.bias   shape:: torch.Size([12])
    arm_loc_2.weight   shape:: torch.Size([12, 1024, 3, 3])
    arm_loc_2.bias   shape:: torch.Size([12])
    arm_loc_3.weight   shape:: torch.Size([12, 512, 3, 3])
    arm_loc_3.bias   shape:: torch.Size([12])
    arm_conf_0.weight   shape:: torch.Size([6, 512, 3, 3])
    arm_conf_0.bias   shape:: torch.Size([6])
    arm_conf_1.weight   shape:: torch.Size([6, 512, 3, 3])
    arm_conf_1.bias   shape:: torch.Size([6])
    arm_conf_2.weight   shape:: torch.Size([6, 1024, 3, 3])
    arm_conf_2.bias   shape:: torch.Size([6])
    arm_conf_3.weight   shape:: torch.Size([6, 512, 3, 3])
    arm_conf_3.bias   shape:: torch.Size([6])
    tcb0_9.weight   shape:: torch.Size([256, 512, 3, 3])
    tcb0_9.bias   shape:: torch.Size([256])
    tcb0_11.weight   shape:: torch.Size([256, 256, 3, 3])
    tcb0_11.bias   shape:: torch.Size([256])
    tcb0_6.weight   shape:: torch.Size([256, 1024, 3, 3])
    tcb0_6.bias   shape:: torch.Size([256])
    tcb0_8.weight   shape:: torch.Size([256, 256, 3, 3])
    tcb0_8.bias   shape:: torch.Size([256])
    tcb0_3.weight   shape:: torch.Size([256, 512, 3, 3])
    tcb0_3.bias   shape:: torch.Size([256])
    tcb0_5.weight   shape:: torch.Size([256, 256, 3, 3])
    tcb0_5.bias   shape:: torch.Size([256])
    tcb0_0.weight   shape:: torch.Size([256, 512, 3, 3])
    tcb0_0.bias   shape:: torch.Size([256])
    tcb0_2.weight   shape:: torch.Size([256, 256, 3, 3])
    tcb0_2.bias   shape:: torch.Size([256])
    tcb2_1.weight   shape:: torch.Size([256, 256, 3, 3])
    tcb2_1.bias   shape:: torch.Size([256])
    tcb2_4.weight   shape:: torch.Size([256, 256, 3, 3])
    tcb2_4.bias   shape:: torch.Size([256])
    tcb2_7.weight   shape:: torch.Size([256, 256, 3, 3])
    tcb2_7.bias   shape:: torch.Size([256])
    tcb2_10.weight   shape:: torch.Size([256, 256, 3, 3])
    tcb2_10.bias   shape:: torch.Size([256])
    tcb1_2.weight   shape:: torch.Size([256, 256, 2, 2])
    tcb1_2.bias   shape:: torch.Size([256])
    tcb1_1.weight   shape:: torch.Size([256, 256, 2, 2])
    tcb1_1.bias   shape:: torch.Size([256])
    tcb1_0.weight   shape:: torch.Size([256, 256, 2, 2])
    tcb1_0.bias   shape:: torch.Size([256])
    odm_loc_0.weight   shape:: torch.Size([12, 256, 3, 3])
    odm_loc_0.bias   shape:: torch.Size([12])
    odm_loc_1.weight   shape:: torch.Size([12, 256, 3, 3])
    odm_loc_1.bias   shape:: torch.Size([12])
    odm_loc_2.weight   shape:: torch.Size([12, 256, 3, 3])
    odm_loc_2.bias   shape:: torch.Size([12])
    odm_loc_3.weight   shape:: torch.Size([12, 256, 3, 3])
    odm_loc_3.bias   shape:: torch.Size([12])
    odm_conf_0.weight   shape:: torch.Size([75, 256, 3, 3])
    odm_conf_0.bias   shape:: torch.Size([75])
    odm_conf_1.weight   shape:: torch.Size([75, 256, 3, 3])
    odm_conf_1.bias   shape:: torch.Size([75])
    odm_conf_2.weight   shape:: torch.Size([75, 256, 3, 3])
    odm_conf_2.bias   shape:: torch.Size([75])
    odm_conf_3.weight   shape:: torch.Size([75, 256, 3, 3])
    odm_conf_3.bias   shape:: torch.Size([75])
    @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    --------load pth name----------------------
    vgg.0.weight   shape== torch.Size([64, 3, 3, 3])
    vgg.0.bias   shape== torch.Size([64])
    vgg.2.weight   shape== torch.Size([64, 64, 3, 3])
    vgg.2.bias   shape== torch.Size([64])
    vgg.5.weight   shape== torch.Size([128, 64, 3, 3])
    vgg.5.bias   shape== torch.Size([128])
    vgg.7.weight   shape== torch.Size([128, 128, 3, 3])
    vgg.7.bias   shape== torch.Size([128])
    vgg.10.weight   shape== torch.Size([256, 128, 3, 3])
    vgg.10.bias   shape== torch.Size([256])
    vgg.12.weight   shape== torch.Size([256, 256, 3, 3])
    vgg.12.bias   shape== torch.Size([256])
    vgg.14.weight   shape== torch.Size([256, 256, 3, 3])
    vgg.14.bias   shape== torch.Size([256])
    vgg.17.weight   shape== torch.Size([512, 256, 3, 3])
    vgg.17.bias   shape== torch.Size([512])
    vgg.19.weight   shape== torch.Size([512, 512, 3, 3])
    vgg.19.bias   shape== torch.Size([512])
    vgg.21.weight   shape== torch.Size([512, 512, 3, 3])
    vgg.21.bias   shape== torch.Size([512])
    vgg.24.weight   shape== torch.Size([512, 512, 3, 3])
    vgg.24.bias   shape== torch.Size([512])
    vgg.26.weight   shape== torch.Size([512, 512, 3, 3])
    vgg.26.bias   shape== torch.Size([512])
    vgg.28.weight   shape== torch.Size([512, 512, 3, 3])
    vgg.28.bias   shape== torch.Size([512])
    vgg.31.weight   shape== torch.Size([1024, 512, 3, 3])
    vgg.31.bias   shape== torch.Size([1024])
    vgg.33.weight   shape== torch.Size([1024, 1024, 1, 1])
    vgg.33.bias   shape== torch.Size([1024])
    conv4_3_L2Norm.weight   shape== torch.Size([512])
    conv5_3_L2Norm.weight   shape== torch.Size([512])
    extras.0.weight   shape== torch.Size([256, 1024, 1, 1])
    extras.0.bias   shape== torch.Size([256])
    extras.1.weight   shape== torch.Size([512, 256, 3, 3])
    extras.1.bias   shape== torch.Size([512])
    arm_loc.0.weight   shape== torch.Size([12, 512, 3, 3])
    arm_loc.0.bias   shape== torch.Size([12])
    arm_loc.1.weight   shape== torch.Size([12, 512, 3, 3])
    arm_loc.1.bias   shape== torch.Size([12])
    arm_loc.2.weight   shape== torch.Size([12, 1024, 3, 3])
    arm_loc.2.bias   shape== torch.Size([12])
    arm_loc.3.weight   shape== torch.Size([12, 512, 3, 3])
    arm_loc.3.bias   shape== torch.Size([12])
    arm_conf.0.weight   shape== torch.Size([6, 512, 3, 3])
    arm_conf.0.bias   shape== torch.Size([6])
    arm_conf.1.weight   shape== torch.Size([6, 512, 3, 3])
    arm_conf.1.bias   shape== torch.Size([6])
    arm_conf.2.weight   shape== torch.Size([6, 1024, 3, 3])
    arm_conf.2.bias   shape== torch.Size([6])
    arm_conf.3.weight   shape== torch.Size([6, 512, 3, 3])
    arm_conf.3.bias   shape== torch.Size([6])
    odm_loc.0.weight   shape== torch.Size([12, 256, 3, 3])
    odm_loc.0.bias   shape== torch.Size([12])
    odm_loc.1.weight   shape== torch.Size([12, 256, 3, 3])
    odm_loc.1.bias   shape== torch.Size([12])
    odm_loc.2.weight   shape== torch.Size([12, 256, 3, 3])
    odm_loc.2.bias   shape== torch.Size([12])
    odm_loc.3.weight   shape== torch.Size([12, 256, 3, 3])
    odm_loc.3.bias   shape== torch.Size([12])
    odm_conf.0.weight   shape== torch.Size([75, 256, 3, 3])
    odm_conf.0.bias   shape== torch.Size([75])
    odm_conf.1.weight   shape== torch.Size([75, 256, 3, 3])
    odm_conf.1.bias   shape== torch.Size([75])
    odm_conf.2.weight   shape== torch.Size([75, 256, 3, 3])
    odm_conf.2.bias   shape== torch.Size([75])
    odm_conf.3.weight   shape== torch.Size([75, 256, 3, 3])
    odm_conf.3.bias   shape== torch.Size([75])
    tcb0.0.weight   shape== torch.Size([256, 512, 3, 3])
    tcb0.0.bias   shape== torch.Size([256])
    tcb0.2.weight   shape== torch.Size([256, 256, 3, 3])
    tcb0.2.bias   shape== torch.Size([256])
    tcb0.3.weight   shape== torch.Size([256, 512, 3, 3])
    tcb0.3.bias   shape== torch.Size([256])
    tcb0.5.weight   shape== torch.Size([256, 256, 3, 3])
    tcb0.5.bias   shape== torch.Size([256])
    tcb0.6.weight   shape== torch.Size([256, 1024, 3, 3])
    tcb0.6.bias   shape== torch.Size([256])
    tcb0.8.weight   shape== torch.Size([256, 256, 3, 3])
    tcb0.8.bias   shape== torch.Size([256])
    tcb0.9.weight   shape== torch.Size([256, 512, 3, 3])
    tcb0.9.bias   shape== torch.Size([256])
    tcb0.11.weight   shape== torch.Size([256, 256, 3, 3])
    tcb0.11.bias   shape== torch.Size([256])
    tcb1.0.weight   shape== torch.Size([256, 256, 2, 2])
    tcb1.0.bias   shape== torch.Size([256])
    tcb1.1.weight   shape== torch.Size([256, 256, 2, 2])
    tcb1.1.bias   shape== torch.Size([256])
    tcb1.2.weight   shape== torch.Size([256, 256, 2, 2])
    tcb1.2.bias   shape== torch.Size([256])
    tcb2.1.weight   shape== torch.Size([256, 256, 3, 3])
    tcb2.1.bias   shape== torch.Size([256])
    tcb2.4.weight   shape== torch.Size([256, 256, 3, 3])
    tcb2.4.bias   shape== torch.Size([256])
    tcb2.7.weight   shape== torch.Size([256, 256, 3, 3])
    tcb2.7.bias   shape== torch.Size([256])
    tcb2.10.weight   shape== torch.Size([256, 256, 3, 3])
    tcb2.10.bias   shape== torch.Size([256])
    

    可以看到,我们自定义的网络和加载的pth层名有些许不一样,比如net里面的层名

    tcb0_2.weight   shape:: torch.Size([256, 256, 3, 3])
    tcb0_2.bias   shape:: torch.Size([256])
    tcb2_1.weight   shape:: torch.Size([256, 256, 3, 3])
    tcb2_1.bias   shape:: torch.Size([256])
    

    对应pth层名

    tcb0.2.weight   shape== torch.Size([256, 256, 3, 3])
    tcb0.2.bias   shape== torch.Size([256])
    tcb2.1.weight   shape== torch.Size([256, 256, 3, 3])
    tcb2.1.bias   shape== torch.Size([256])
    

    还有其他的,我现在要做的就是把pth的层名改成和net所需要的层名一致,才能加载到net中。对应代码如下:

        # Rename checkpoint keys from the original RefineDet module layout to the
        # flat attribute names of refinedet_my, e.g.
        #   vgg.0.weight     -> conv0.weight
        #   extras.1.bias    -> extras1.bias
        #   arm_loc.0.weight -> arm_loc_0.weight
        #   tcb0.2.weight    -> tcb0_2.weight
        checkpoint = torch.load(path_model, map_location=torch.device('cpu'))
        import collections
        new_state_dict = collections.OrderedDict()

        for k, v in checkpoint.items():
            name = k.replace('vgg.', 'conv')
            name = name.replace('extras.', 'extras')
            if name.startswith(('arm', 'odm', 'tcb')):
                # Only the first '.' (the container index separator) becomes '_';
                # the '.' before weight/bias must stay.
                name = name.replace('.', '_', 1)
            new_state_dict[name] = v

        # strict=False silently skips keys that do not match, which would make the
        # layer-by-layer comparison fail for no visible reason -- report them.
        result = net.load_state_dict(new_state_dict, strict=False)
        if result.missing_keys:
            print("missing keys (not loaded):", result.missing_keys)
        if result.unexpected_keys:
            print("unexpected keys (ignored):", result.unexpected_keys)
    

    这样的话就保证了两边加载了一样的权重了。

    然后在tensorrt里面搭建对应的操作,类似代码如下

    // Create the engine using only the TensorRT network-definition API, not any parser.
    // NOTE(review): this excerpt stops after the conv5_3 block (lr28); the rest of the
    // function (remaining layers, markOutput, engine build, return) is not shown here.
    ICudaEngine* createEngine(unsigned int maxBatchSize, IBuilder* builder, IBuilderConfig* config, DataType dt) {
        // Flags 0U: no NetworkDefinitionCreationFlag set, i.e. an implicit-batch network.
        INetworkDefinition* network = builder->createNetworkV2(0U);
    
        // Input tensor is C x H x W (3 x INPUT_H x INPUT_W); batch is implicit.
        ITensor* data = network->addInput(INPUT_BLOB_NAME, dt, Dims3{3, INPUT_H, INPUT_W});
        assert(data);
    
        // Layer name -> weights, parsed from the exported .wts file.
        std::map<std::string, Weights> weightMap = loadWeights(path_wts);
        // NOTE(review): emptywts and maxpool_hw are not used in the lines shown.
        Weights emptywts{DataType::kFLOAT, nullptr, 0};
        DimsHW maxpool_hw = DimsHW(2,2);
    
        // convRelu is defined elsewhere. The last argument appears to be the index N
        // of the checkpoint layer vgg.N (0, 2, 5, 7, ... match the vgg.* keys) --
        // TODO confirm against convRelu's implementation.
        auto lr0 = convRelu(network, weightMap, *data, 64, 3, 1, 1, 0);
        auto lr1 = convRelu(network, weightMap, *lr0->getOutput(0), 64, 3, 1, 1, 2);
        IPoolingLayer* pool1 = network->addPoolingNd(*lr1->getOutput(0), PoolingType::kMAX, DimsHW{2, 2});
        assert(pool1);
        pool1->setStrideNd(DimsHW{2, 2});
    
        auto lr2 = convRelu(network, weightMap, *pool1->getOutput(0), 128, 3, 1, 1, 5);
        auto lr3 = convRelu(network, weightMap, *lr2->getOutput(0), 128, 3, 1, 1, 7);
        IPoolingLayer* pool2 = network->addPoolingNd(*lr3->getOutput(0), PoolingType::kMAX, DimsHW{2, 2});
        assert(pool2);
        pool2->setStrideNd(DimsHW{2, 2});
    
        auto lr4 = convRelu(network, weightMap, *pool2->getOutput(0), 256, 3, 1, 1, 10);
        auto lr5 = convRelu(network, weightMap, *lr4->getOutput(0), 256, 3, 1, 1, 12);
        auto lr6 = convRelu(network, weightMap, *lr5->getOutput(0), 256, 3, 1, 1, 14);
        // The PyTorch counterpart (maxpool16) uses ceil_mode=True, while this pool uses
        // TensorRT's default round-down. They agree only when the pooled H/W is even
        // (80x80 for a 320x320 input).
        IPoolingLayer* pool3 = network->addPoolingNd(*lr6->getOutput(0), PoolingType::kMAX, DimsHW{2, 2});
        assert(pool3);
        pool3->setStrideNd(DimsHW{2, 2});
    
        auto lr7 = convRelu(network, weightMap, *pool3->getOutput(0), 512, 3, 1, 1, 17);
        auto lr8 = convRelu(network, weightMap, *lr7->getOutput(0), 512, 3, 1, 1, 19);
        auto lr9 = convRelu(network, weightMap, *lr8->getOutput(0), 512, 3, 1, 1, 21);
        // In the PyTorch forward this conv4_3 output also feeds the L2Norm branch;
        // that branch is not part of this excerpt.
        IPoolingLayer* pool4 = network->addPoolingNd(*lr9->getOutput(0), PoolingType::kMAX, DimsHW{2, 2});
        assert(pool4);
        pool4->setStrideNd(DimsHW{2, 2});
    
        // conv5_x block; these variables are named after the vgg.N indices (24, 26, 28).
        auto lr24 = convRelu(network, weightMap, *pool4->getOutput(0), 512, 3, 1, 1, 24);
        auto lr26 = convRelu(network, weightMap, *lr24->getOutput(0), 512, 3, 1, 1, 26);
        auto lr28 = convRelu(network, weightMap, *lr26->getOutput(0), 512, 3, 1, 1, 28);
    

    4.验证精度

    tensorrt比较麻烦,需要先生成engine,然后本地再加载engine推理。
    在搭建模型的时候,想验证哪一层就让这一层输出。

        auto lr0 = convRelu(network, weightMap, *data, 64, 3, 1, 1, 0);
    
        lr0->getOutput(0)->setName("out");
        network->markOutput(*lr0->getOutput(0));
    
    
        builder->setMaxBatchSize(maxBatchSize);
        config->setMaxWorkspaceSize(16 * (1 << 20));  // 16MB
        ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);
        std::cout << "Build engine successfully!" << std::endl;
        // Don't need the network any more
        network->destroy();
        // Release host memory
        for (auto& mem : weightMap)
        {
            free((void*) (mem.second.values));
        }
    
        return engine;
    

    后续代码会在本地生成 engine,然后加载该 engine 跑推理。不过推理代码需要修改:推理时必须按所标记输出层的大小来分配显存。具体见下面的代码:

    // Run a debug engine that has exactly one input and one marked output, and copy the
    // output activation back to host memory so it can be compared with PyTorch.
    //
    // context   : execution context of the built engine.
    // input     : host buffer of batchSize * 3 * INPUT_H_refinedet * INPUT_W_refinedet floats.
    // output    : caller-allocated host buffer of at least batchSize * OUTPUT_SIZE_2 floats.
    // batchSize : implicit batch size passed to enqueue.
    void doInference(IExecutionContext& context, float* input, float* output, int batchSize) {
        const ICudaEngine& engine = context.getEngine();

        // The engine needs exactly getNbBindings() buffers: one input plus one output here.
        std::cout << "engine.getNbBindings()===" << engine.getNbBindings() << std::endl;
        assert(engine.getNbBindings() == 2);
        void* buffers[2];

        // Look up bindings by the tensor names that were set when the network was built.
        const int inputIndex = engine.getBindingIndex(INPUT_BLOB_NAME);
        const int outputIndex = engine.getBindingIndex(OUTPUT_BLOB_NAME);
        printf("inputIndex=%d\n", inputIndex);
        printf("outputIndex=%d\n", outputIndex);
        // getBindingIndex returns -1 for an unknown name; buffers[-1] would be out of bounds.
        assert(inputIndex >= 0 && outputIndex >= 0);

        // Per-sample element count of the marked output layer. It must be updated whenever a
        // different layer is marked as output (this value is a 64 x 160 x 160 activation).
        const int OUTPUT_SIZE_2 = 1 * 64 * 160 * 160;

        // Cross-check the hard-coded size against the engine, otherwise TensorRT could write past
        // the device buffer allocated below. For an implicit-batch engine the binding dims exclude
        // the batch axis. NOTE(review): confirm the engine is built in implicit-batch mode.
        const Dims outDims = engine.getBindingDimensions(outputIndex);
        long long engineVolume = 1;
        for (int d = 0; d < outDims.nbDims; ++d) {
            engineVolume *= outDims.d[d];
        }
        if (engineVolume != OUTPUT_SIZE_2) {
            std::cerr << "doInference: engine output has " << engineVolume
                      << " elements per sample but OUTPUT_SIZE_2 is " << OUTPUT_SIZE_2 << std::endl;
            return;
        }

        const size_t inputBytes = static_cast<size_t>(batchSize) * 3 * INPUT_H_refinedet * INPUT_W_refinedet * sizeof(float);
        const size_t outputBytes = static_cast<size_t>(batchSize) * OUTPUT_SIZE_2 * sizeof(float);

        // Create GPU buffers on device.
        CUDA_CHECK(cudaMalloc(&buffers[inputIndex], inputBytes));
        CUDA_CHECK(cudaMalloc(&buffers[outputIndex], outputBytes));

        cudaStream_t stream;
        CUDA_CHECK(cudaStreamCreate(&stream));

        // Host -> device copy, asynchronous inference, device -> host copy, all on one stream.
        CUDA_CHECK(cudaMemcpyAsync(buffers[inputIndex], input, inputBytes, cudaMemcpyHostToDevice, stream));
        if (!context.enqueue(batchSize, buffers, stream, nullptr)) {
            std::cerr << "doInference: context.enqueue failed" << std::endl;
        }
        CUDA_CHECK(cudaMemcpyAsync(output, buffers[outputIndex], outputBytes, cudaMemcpyDeviceToHost, stream));
        // Block until the output copy has landed in host memory.
        CUDA_CHECK(cudaStreamSynchronize(stream));

        // Release stream and buffers.
        CUDA_CHECK(cudaStreamDestroy(stream));
        CUDA_CHECK(cudaFree(buffers[inputIndex]));
        CUDA_CHECK(cudaFree(buffers[outputIndex]));
    }
    

    然后在函数外面就可以读取 output 中的值,与 pytorch 中对应层的输出做对比,检查两者是否一致。正常情况下应该是一致的。

    5.不支持的层实现

    对于不支持的层需要自己实现,写plugin或者用既有的api实现。
    这里我耗费了几天。见下面博客。
    【L2norm 层tensorrt api实现】(https://www.cnblogs.com/yanghailin/p/14448829.html)

    还有一个softmax也耗费了我大半天时间:
    【指定维度softmax 层tensorRT api实现】(https://www.cnblogs.com/yanghailin/p/14486077.html)

    6.后处理

    后处理是个大坑。
    也不叫大坑吧,是有难度。因为需要用到cuda编程。需要把任务分解成可并行的格式来做。
    我看到的yolov3里面就是自己写了个plugin来做后处理。
    我一开始也准备这么搞的,之前也看了centernet的tensorrt实现。
    https://github.com/CaoWGG/TensorRT-CenterNet
    https://github.com/CaoWGG/TensorRT-CenterNet/blob/f949252e37b51e60f873808f46d3683f15735e79/src/ctdetNet.cpp#L146
    这份代码也比较优秀,基于 tensorrt 5.0 实现。它先用 tensorrt 推理得到 cuda 上的数组,再用自己写的 cuda 函数完成后处理。
    然后我就想,我也可以这么做,写plugin要按照好多格式太麻烦了。
    这里还有个问题是refinedet网络出来的是4个tensor,如果要4个输出的话没有看到有人这么做过,但是后面我自己摸索,是可以输出多个tensor的。

        arm_loc->getOutput(0)->setName(OUTPUT_BLOB_NAME_arm_loc);
        network->markOutput(*arm_loc->getOutput(0));
    
      arm_conf_111->getOutput(0)->setName(OUTPUT_BLOB_NAME_arm_conf);
        network->markOutput(*arm_conf_111->getOutput(0));
    
     odm_loc->getOutput(0)->setName(OUTPUT_BLOB_NAME_odm_loc);
        network->markOutput(*odm_loc->getOutput(0));
    
     odm_conf->getOutput(0)->setName(OUTPUT_BLOB_NAME_odm_conf);
        network->markOutput(*odm_conf->getOutput(0));
    

    定义网络的时候直接标记多个输出即可。推理时 engine.getNbBindings() 返回 5(1 个输入 + 4 个输出),context.enqueue(batchSize, buffers, stream, nullptr); 里的 buffers 就是由这 5 个显存指针组成的数组。

    // Multi-output RefineDet inference (1 input + 4 outputs = 5 bindings).
    // `buffers` is expected to hold device allocations for all 5 bindings and `stream` a valid
    // CUDA stream; the excerpt stops after enqueue (post-processing follows in the full code).
    void doInference(IExecutionContext& context, void* buffers[], cudaStream_t &stream, float* input, std::vector<std::vector<float>> &detections) {
        auto start_infer = std::chrono::system_clock::now();
        detections.clear();
        int batchSize = 1;
        const ICudaEngine& engine = context.getEngine();

        // Pointers to input and output device buffers to pass to engine.
        // Engine requires exactly IEngine::getNbBindings() number of buffers.
    //    std::cout<<"engine.getNbBindings()==="<<engine.getNbBindings()<<std::endl;
        assert(engine.getNbBindings() == 5);

        // In order to bind the buffers, we need to know the names of the input and output tensors.
        // Note that indices are guaranteed to be less than IEngine::getNbBindings()
        const int inputIndex = engine.getBindingIndex(INPUT_BLOB_NAME);
        const int outputIndex_arm_loc = engine.getBindingIndex(OUTPUT_BLOB_NAME_arm_loc);
        const int outputIndex_arm_conf = engine.getBindingIndex(OUTPUT_BLOB_NAME_arm_conf);
        const int outputIndex_odm_loc = engine.getBindingIndex(OUTPUT_BLOB_NAME_odm_loc);
        const int outputIndex_odm_conf = engine.getBindingIndex(OUTPUT_BLOB_NAME_odm_conf);
        // getBindingIndex returns -1 for an unknown name; buffers[-1] would be out of bounds.
        assert(inputIndex >= 0 && outputIndex_arm_loc >= 0 && outputIndex_arm_conf >= 0
               && outputIndex_odm_loc >= 0 && outputIndex_odm_conf >= 0);
    //    const int outputIndex2 = engine.getBindingIndex("prob2");
    //    printf("inputIndex=%d\n",inputIndex);
    //    printf("outputIndex_arm_loc=%d\n",outputIndex_arm_loc);
    //    printf("outputIndex_arm_conf=%d\n",outputIndex_arm_conf);
    //    printf("outputIndex_odm_loc=%d\n",outputIndex_odm_loc);
    //    printf("outputIndex_odm_conf=%d\n",outputIndex_odm_conf);

        // DMA input batch data to device, then infer on the batch asynchronously.
        CUDA_CHECK(cudaMemcpyAsync(buffers[inputIndex], input, batchSize * 3 * INPUT_H * INPUT_W * sizeof(float), cudaMemcpyHostToDevice, stream));
        if (!context.enqueue(batchSize, buffers, stream, nullptr)) {
            std::cerr << "doInference: context.enqueue failed" << std::endl;
        }
    

    context.enqueue(batchSize, buffers, stream, nullptr);
    所以其实到这里为止,我就得到了4个在cuda上面的tensor,并且数值和pytorch里面是一致的。
    看cuda编程看了几天,对应我们这个任务还是无从下手。请教了大神,说尽快和我讨论。然后就没有下文了。
    没有办法,我仔细盯着pytorch的后处理实现

     def forward(self, num_classes, size, bkg_label, top_k, conf_thresh, nms_thresh,
                    objectness_thre, keep_top_k, arm_loc_data, arm_conf_data, odm_loc_data, odm_conf_data, prior_data):
            """Decode RefineDet ARM/ODM outputs and run per-class NMS.

            Args:
                num_classes: number of classes, background included.
                size: input size; ``cfg[str(size)]['variance']`` supplies the box variances.
                bkg_label: background label (stored only; the class loop starts at 1).
                top_k: maximum number of boxes kept per class by ``nms``.
                conf_thresh: ODM score a box must exceed to enter NMS.
                nms_thresh: IoU threshold passed to ``nms``; must be > 0.
                objectness_thre: ARM objectness at or below which all ODM class
                    scores of that prior are zeroed.
                keep_top_k: number of highest-scoring detections, over all
                    classes, left non-zero in the result.
                arm_loc_data: ARM box regressions, e.g. torch.Size([1, 6375, 4]).
                arm_conf_data: ARM scores, e.g. torch.Size([1, 6375, 2]);
                    column 1 is used as objectness.
                odm_loc_data: ODM box regressions, e.g. torch.Size([1, 6375, 4]).
                odm_conf_data: ODM class scores, e.g. torch.Size([1, 6375, 25]).
                    Modified in place: scores of low-objectness priors are set to 0.
                prior_data: prior boxes, e.g. torch.Size([6375, 4]).

            Returns:
                Tensor of shape [batch, num_classes, top_k, 5]. Each row is
                (score, 4 box values produced by ``decode``); unused rows are 0.

            Raises:
                ValueError: if ``nms_thresh <= 0``.
            """
            self.num_classes = num_classes
            self.background_label = bkg_label
            self.top_k = top_k
            self.keep_top_k = keep_top_k
            # Parameters used in nms.
            self.nms_thresh = nms_thresh
            if nms_thresh <= 0:
                raise ValueError('nms_threshold must be positive.')
            self.conf_thresh = conf_thresh
            self.objectness_thre = objectness_thre
            self.variance = cfg[str(size)]['variance']

            loc_data = odm_loc_data
            conf_data = odm_conf_data
            # ARM objectness (column 1), shape [batch, num_priors, 1].
            arm_object_conf = arm_conf_data.data[:, :, 1:]
            no_object_index = arm_object_conf <= self.objectness_thre
            # Zero every class score of priors whose objectness is too low
            # (this writes into odm_conf_data).
            conf_data[no_object_index.expand_as(conf_data)] = 0

            num = loc_data.size(0)  # batch size
            num_priors = prior_data.size(0)
            output = torch.zeros(num, self.num_classes, self.top_k, 5)
            # [batch, num_classes, num_priors]
            conf_preds = conf_data.view(num, num_priors, self.num_classes).transpose(2, 1)

            for i in range(num):
                # Two-step decoding: arm_loc_data refines the priors, then
                # odm_loc_data is decoded relative to those refined boxes.
                default = decode(arm_loc_data[i], prior_data, self.variance)
                default = center_size(default)
                decoded_boxes = decode(loc_data[i], default, self.variance)
                conf_scores = conf_preds[i].clone()  # [num_classes, num_priors]
                for cl in range(1, self.num_classes):
                    c_mask = conf_scores[cl].gt(self.conf_thresh)
                    scores = conf_scores[cl][c_mask]
                    if scores.size(0) == 0:
                        continue
                    # A boolean row mask selects the [k, 4] boxes that passed the threshold.
                    boxes = decoded_boxes[c_mask].view(-1, 4)
                    # idx of highest scoring and non-overlapping boxes per class
                    ids, count = nms(boxes, scores, self.nms_thresh, self.top_k)
                    output[i, cl, :count] = torch.cat(
                        (scores[ids[:count]].unsqueeze(1), boxes[ids[:count]]), 1)

            # flt is a view of output (output is already contiguous), so writes
            # through flt change the returned tensor.
            flt = output.contiguous().view(num, -1, 5)
            _, idx = flt[:, :, 0].sort(1, descending=True)
            _, rank = idx.sort(1)
            # Keep only the keep_top_k best detections: zero every row ranked at or
            # beyond keep_top_k. Boolean-mask *assignment* writes in place; the
            # former ``flt[mask].fill_(0)`` filled a temporary copy (and used the
            # inverted comparison), so it had no effect.
            flt[(rank >= self.keep_top_k).unsqueeze(-1).expand_as(flt)] = 0
            return output
    

    挺复杂的,都是坐标索引,维度变换,想用cuda实现,想想就头疼。cuda上面各个tensor就是一维数组。这咋搞。
    然后有一天盯着这段代码,抱着试试的心态,把cuda数据看看能不能转libtorch,因为我libtorch的实现是已经写好的。
    然后在 CMakeLists.txt 里面添加 libtorch 库。

     assert(engine.getNbBindings() == 5);
    
        // In order to bind the buffers, we need to know the names of the input and output tensors.
        // Note that indices are guaranteed to be less than IEngine::getNbBindings()
        const int inputIndex = engine.getBindingIndex(INPUT_BLOB_NAME);
        const int outputIndex_arm_loc = engine.getBindingIndex(OUTPUT_BLOB_NAME_arm_loc);
        const int outputIndex_arm_conf = engine.getBindingIndex(OUTPUT_BLOB_NAME_arm_conf);
        const int outputIndex_odm_loc = engine.getBindingIndex(OUTPUT_BLOB_NAME_odm_loc);
        const int outputIndex_odm_conf = engine.getBindingIndex(OUTPUT_BLOB_NAME_odm_conf);
    //    const int outputIndex2 = engine.getBindingIndex("prob2");
    //    printf("inputIndex=%d
    ",inputIndex);
    //    printf("outputIndex_arm_loc=%d
    ",outputIndex_arm_loc);
    //    printf("outputIndex_arm_conf=%d
    ",outputIndex_arm_conf);
    //    printf("outputIndex_odm_loc=%d
    ",outputIndex_odm_loc);
    //    printf("outputIndex_odm_conf=%d
    ",outputIndex_odm_conf);
    
        // DMA input batch data to device, infer on the batch asynchronously, and DMA output back to host
        CUDA_CHECK(cudaMemcpyAsync(buffers[inputIndex], input, batchSize * 3 * INPUT_H * INPUT_W * sizeof(float), cudaMemcpyHostToDevice, stream));
        context.enqueue(batchSize, buffers, stream, nullptr);
       
        int m_prior_size = 6375;
        torch::Tensor arm_loc = torch::from_blob(buffers[outputIndex_arm_loc],{m_prior_size,4}).cuda().toType(torch::kFloat64).unsqueeze(0);
        torch::Tensor arm_conf = torch::from_blob(buffers[outputIndex_arm_conf],{m_prior_size,2}).cuda().toType(torch::kFloat64).unsqueeze(0);
        torch::Tensor odm_loc = torch::from_blob(buffers[outputIndex_odm_loc],{m_prior_size,4}).cuda().toType(torch::kFloat64).unsqueeze(0);
        torch::Tensor odm_conf = torch::from_blob(buffers[outputIndex_odm_conf],{m_prior_size,25}).cuda().toType(torch::kFloat64).unsqueeze(0);
    

    居然可以啊!!!
    厉害了,试了这里可以访问torch::Tensor arm_loc里面数据,并且是对的!
    然后我把后处理代码搬过来,是可以跑通出效果图的!
    至此,后处理就可以用libtorch,摆脱cuda编程了!

    7.性能比较


    int8 会多检测出很多框,但评价脚本对多出的框没有惩罚,所以两者精度差不多。
    int8多出框导致后处理时间变长。

    小弟不才,同时谢谢友情赞助!

    好记性不如烂键盘---点滴、积累、进步!
  • 相关阅读:
    平衡二叉树的遍历/删除/新增/维护平衡因子
    二分查找算法(加法方式:斐波那契查找)
    Ubuntu14 配置开机自启动/关闭
    JAVA & JSON详解
    jQuery---EasyUI小案列
    jquery---基本标签
    NoSql---MongoDB基本操作
    Java框架篇---Mybatis 构建SqlSessionFactory
    Java框架篇---Mybatis 入门
    hessian入门
  • 原文地址:https://www.cnblogs.com/yanghailin/p/14525128.html
Copyright © 2020-2023  润新知