• [Paddle学习笔记][13][基于YOLOv3的昆虫检测-测试模型]


    说明:

    本例程使用YOLOv3进行昆虫检测。例程分为数据处理、模型设计、损失函数、训练模型、模型预测和测试模型六个部分。本篇为第六部分,保存非极大值抑制输出的结果到预测结果文件,然后通过完整插值方法计算mAP。非极大值阈值的预测得分需要设置一个低的得分,使得计算mAP时能比较更多的平均精度。

    实验代码:

    测试模型:

    import json
    import paddle.fluid as fluid
    from paddle.fluid.dygraph.base import to_variable
    
    from source.data import multip_test_reader
    from source.model import YOLOv3
    from source.infer import get_nms_infer
    from source.test import test
    
    num_classes = 7                                                                              # 类别数量
    anchor_size = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326] # 锚框大小
    anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]                                              # 锚框掩码
    downsample_ratio = 32                                                                        # 下采样率
    
    test_path = './dataset/val/images'         # 测试目录路径
    json_path = './output/infer.json'          # 结果保存路径
    xmls_path='./dataset/val/annotations/xmls' # 标签目录路径
    model_path = './output/darknet53-yolov3'   # 网络权重路径
    sco_threshold = 0.01                       # 预测得分阈值:设置一个小值,使得测试能够比较更多的准确率
    nms_threshold = 0.45                       # 非极大值阈值:消除重叠大于该阈值的的预测边框
    iou_threshold = 0.50                       # 测试交并比值:保留与真实边框大于该阈值的预测边框
    
    with fluid.dygraph.guard():
        # 准备数据
        test_reader = multip_test_reader(test_path, batch_size=8, scale_size=(608, 608))
        
        # 加载模型
        model = YOLOv3(num_classes=num_classes, anchor_mask=anchor_mask) # 加载模型
        model_dict, _ = fluid.load_dygraph(model_path)                   # 加载权重
        model.load_dict(model_dict)                                      # 设置权重
        model.eval()                                                     # 设置验证
        
        # 模型预测
        infer_list = [] # 预测结果列表
        
        for test_data in test_reader():
            # 读取图像
            image_name, image, image_size = test_data # 读取数据
            image = to_variable(image)                # 转换格式
            image_size = to_variable(image_size)      # 转换格式
            
            # 前向传播
            infer = model(image)
            
            # 获取结果
            infer = get_nms_infer(infer, image_size, num_classes, anchor_size, anchor_mask, downsample_ratio, 
                                  sco_threshold, nms_threshold)
            
            # 添加列表
            for i in range(len(infer)): # 遍历批次
                if(len(infer[i]) > 0):  # 是否存在物体
                    infer_list.append([image_name[i], infer[i].tolist()])
            print('Processed {} images...'.format(len(infer_list)), end='
    ')
            
        # 保存结果
        print('Svae {} results to infer.json.'.format(len(infer_list)))
        json.dump(infer_list, open(json_path, 'w'))
        
        # 测试模型
        test(json_path, xmls_path, num_classes, iou_threshold)

    结果:

    Svae 245 results to infer.json

    Detection mAP(0.50) = 87.63%

    测试结果

    darknet53-yolov3_050 Detection mAP(0.50) = 64.87%

    darknet53-yolov3_100 Detection mAP(0.50) = 81.02%

    darknet53-yolov3_150 Detection mAP(0.50) = 87.63%

    test.py文件

    import os
    import json
    import math
    import numpy as np
    import xml.etree.ElementTree as ET
    
    # 计算平均精度
    class DetectionMAP(object):
        def __init__(self, num_classes, iou_threshold=0.5):
            """
            功能: 
                初始化计算平均精度方法
            输入: 
                num_classes   - 预测类别数量
                iou_threshold - 测试交并比值
            输出:
            """
            self.num_classes = num_classes                     # 预测类别数量
            self.iou_threshold = iou_threshold                 # 测试交并比值
            self.count = [0] * self.num_classes                # 数量统计列表
            self.score = [[] for _ in range(self.num_classes)] # 得分统计列表
            
        def update(self, infer, gtbox, gtcls):
            """
            功能: 
                统计各类数量和得分
            输入: 
                infer - 预测结果
                gtbox - 物体边框
                gtcls - 物体类别
            输出:
            """
            # 统计各类数量
            for gtcls_item in gtcls:
                self.count[int(np.array(gtcls_item))] += 1
            
            # 统计各类得分
            visited = [False] * len(gtcls) # 各类访问标识
            for infer_item in infer:
                # 获取预测数据
                pdcls, pdsco, xmin, ymin, xmax, ymax = infer_item.tolist() # 获取预测数据
                pdbox = [xmin, ymin, xmax, ymax]                           # 获取预测边框
                
                # 计算最大边框
                max_index = -1 # 最大交并索引
                max_iou = -1.0 # 最大交并比值
                for i, gtcls_item in enumerate(gtcls): # 遍历真实类别列表
                    if int(gtcls_item) == int(pdcls): # 如果真实类别等于预测类别,则计算交并比值
                        iou = self.get_box_iou_xyxy(pdbox, gtbox[i])
                        if iou > max_iou: # 如果交并比值大于最大交并比值,则更新最大交并比值和索引
                            max_index = i
                            max_iou = iou
                
                # 统计各类得分
                if max_iou > self.iou_threshold: # 如果最大交并比值大于测试交并比值
                    if not visited[max_index]: # 如果该物体没有被统计,则添加到列表,并设置访问标识
                        self.score[int(pdcls)].append([pdsco, 1.0]) # 添加各类正确正例
                        visited[max_index] = True                   # 设置访问标识为真
                    else: # 如果该物体已经被统计,则添加到列表,并设置为成错误正例
                        self.score[int(pdcls)].append([pdsco, 0.0]) # 添加各类错误正例
                else: # 如果最大交并比值不大于测试交并比值,则添加到列表,并设置成错误正例
                    self.score[int(pdcls)].append([pdsco, 0.0])     # 添加各类错误正例
            
        def get_box_iou_xyxy(self, box1, box2):
            """
            功能: 
                计算边框交并比值
            输入: 
                box1 - 边界框1
                box2 - 边界框2
            输出:
                iou  - 交并比值
            """
            # 计算交集面积
            x1_min, y1_min, x1_max, y1_max = box1[0], box1[1], box1[2], box1[3]
            x2_min, y2_min, x2_max, y2_max = box2[0], box2[1], box2[2], box2[3]
    
            x_min = np.maximum(x1_min, x2_min)
            y_min = np.maximum(y1_min, y2_min)
            x_max = np.minimum(x1_max, x2_max)
            y_max = np.minimum(y1_max, y2_max)
    
            w = np.maximum(x_max - x_min + 1.0, 0)
            h = np.maximum(y_max - y_min + 1.0, 0)
    
            intersection = w * h # 交集面积
    
            # 计算并集面积
            s1 = (y1_max - y1_min + 1.0) * (x1_max - x1_min + 1.0)
            s2 = (y2_max - y2_min + 1.0) * (x2_max - x2_min + 1.0)
    
            union = s1 + s2 - intersection # 并集面积
    
            # 计算交并比
            iou = intersection / union
    
            return iou
        
        def get_mAP(self):
            """
            功能:
                计算各类平均精度
            输入:
            输出:
                mAP - 各类平均精度
            """
            # 计算每类精度
            mAP = 0 # 各类平均精度
            cnt = 0 # 各类类别计数
            for score, count in zip(self.score, self.count): # 遍历每类物体
                # 统计正误正例
                if count == 0 or len(score) == 0: # 如果该类数量为0,或得分列表为空,则继续下一个类别
                    continue
                tp_list, fp_list = self.get_tp_fp_list(score) # 统计正误正例
                
                # 计算预测的准确率和召回率
                precision = [] # 准确率列表
                recall = []    # 召回率列表
                for tp, fp in zip(tp_list, fp_list):
                    precision.append(float(tp) / (tp + fp)) # 添加准确率
                    recall.append(float(tp) / count)        # 添加召回率
                
                # 计算平均精度
                AP = 0.0         # 平均精度
                pre_recall = 0.0 # 前召回率
                for i in range(len(precision)): # 遍历正确率列表
                    recall_gap = math.fabs(recall[i] - pre_recall) # 计算召回率差值
                    if recall_gap > 1e-6: # 如果召回率改变,则计算平均精度,更新前召回率
                        AP += precision[i] * recall_gap # 累加平均精度
                        pre_recall = recall[i]          # 更新前召回率
                
                # 更新各类精度
                mAP += AP # 累加各类精度
                cnt += 1  # 增加类别计数
                
            # 计算平均精度
            mAP = (mAP / float(cnt)) if cnt > 0 else mAP
            
            return mAP
    
        def get_tp_fp_list(self, score):
            """
            功能:
                对得分列表进行从大到小排序,按排序统计正确正例和错误正例数量
            输入:
                score   - 得分列表
            输出:
                tp_list - 正确正例列表
                fp_list - 错误正例列表
            """
            tp = 0       # 正确正例数量
            fp = 0       # 错误正例数量
            tp_list = [] # 正确正例列表
            fp_list = [] # 错误正例列表
            
            score_list = sorted(score, key=lambda s: s[0], reverse=True) # 对得分列表按从大到小排序
            for (score, label) in score_list:
                tp += int(label)     # 统计正确正例
                tp_list.append(tp)   # 添加正确正例
                fp += 1 - int(label) # 统计错误正例
                fp_list.append(fp)   # 添加错误正例
            
            return tp_list, fp_list
        
    ##############################################################################################################
    
    object_names = ['Boerner', 'Leconte', 'Linnaeus', 'acuminatus', 'armandi', 'coleoptera', 'linnaeus'] # 物体名称
    def get_object_gtcls():
        """
        功能:
            将物体名称映射成物体类别
        输入:
        输出:
            object_gtcls - 物体类别
        """
        object_gtcls = {} # 物体类别字典
        for key, value in enumerate(object_names):
            object_gtcls[value] = key # 将物体名称映射成物体类别
        return object_gtcls
    
    def test(json_path, xmls_path, num_classes, iou_threshold):
        """
        功能:
            测试模型平均精度
        输入:
            json_path     - 预测结果路径
            xmls_path     - 标签目录路径
            num_classes   - 预测类别数量
            iou_threshold - 测试交并比值
        输出:
        """
        # 声明计算方法
        mAP = DetectionMAP(num_classes, iou_threshold)
        
        # 统计预测得分
        json_list = json.load(open(json_path))               # 读取预测结果
        for json_item in json_list: # 遍历预测结果
            # 读取预测文件
            image_name = str(json_item[0])                   # 读取文件名称
            infer = np.array(json_item[1]).astype('float32') # 读取预测结果
            
            # 读取标签文件
            tree = ET.parse(os.path.join(xmls_path, image_name + '.xml')) # 解析文件
            image_w = float(tree.find('size').find('width').text)         # 图像宽度
            image_h = float(tree.find('size').find('height').text)        # 图像高度
            
            object_list = tree.findall('object')                     # 物体列表
            gtbox = np.zeros((len(object_list), 4), dtype='float32') # 物体边框
            gtcls = np.zeros((len(object_list),  ), dtype='int32')   # 物体类别
            
            for i, object_item in enumerate(object_list):
                # 读取物体边框
                x_min = float(object_item.find('bndbox').find('xmin').text) # 物体边框x1
                y_min = float(object_item.find('bndbox').find('ymin').text) # 物体边框y1
                x_max = float(object_item.find('bndbox').find('xmax').text) # 物体边框x2
                y_max = float(object_item.find('bndbox').find('ymax').text) # 物体边框y2
                
                x_min = max(0.0, x_min)
                y_min = max(0.0, y_min)
                x_max = min(x_max, image_w - 1.0)
                y_max = min(y_max, image_h - 1.0)
                
                gtbox[i] = [x_min, y_min, x_max, y_max] # 设置物体边框
                
                # 读取物体类别
                object_name = object_item.find('name').text # 读取物体名称
                gtcls[i] = get_object_gtcls()[object_name]  # 将物体名称映射成物体类别
            
            # 统计预测得分
            mAP.update(infer, gtbox, gtcls)
            
        # 计算平均精度
        mAP_value = mAP.get_mAP() * 100 # 计算平均精度
        print("Detection mAP({:.2f}) = {:.2f}%".format(iou_threshold, mAP_value))

    参考资料:

    https://blog.csdn.net/qq_31511955/article/details/89022037

    https://blog.csdn.net/weixin_41278720/article/details/88774411

    https://blog.csdn.net/wc996789331/article/details/83785993

    https://blog.csdn.net/litt1e/article/details/88814417

    https://blog.csdn.net/litt1e/article/details/88852745

    https://blog.csdn.net/litt1e/article/details/88907542

    https://aistudio.baidu.com/aistudio/projectdetail/742781

    https://aistudio.baidu.com/aistudio/projectdetail/672017

    https://aistudio.baidu.com/aistudio/projectdetail/868589

    https://aistudio.baidu.com/aistudio/projectdetail/122277

  • 相关阅读:
    css之overflow注意事项,分析效果没有实现的原因及解决
    Leetcode- 299. Bulls and Cows
    Leetcode-234. Palindrome Linked List
    Leetcode-228 Summary Ranges
    Leetcode-190. Reverse Bits
    盒子模型的理解
    css各类伪元素总结以及清除浮动方法总结
    Leetcode-231. Power of Two
    Uncaught TypeError: __WEBPACK_IMPORTED_MODULE_0_vue__.default.user is not a
    git commit -m ''后报eslint错误
  • 原文地址:https://www.cnblogs.com/d442130165/p/13685806.html
Copyright © 2020-2023  润新知