• YOLOV5——计算预测数据的精确率(precision)和召回率(recall)


    1、预测数据

      使用 yolov5 预测输出的是带标记框的图片,所以需要先将预测的数据输出为文本格式,如json。

      可以参考我之前的博客 https://www.cnblogs.com/yxyun/p/14250455.html

    json文件格式如下图

    [
        {
            "name":"01_05_0002.jpg",
            "category":"1",
            "bbox":[
                3550,
                1813,
                4106,
                2468
            ],
            "score":0.9482421875
        },
        {
            "name":"01_05_0002.jpg",
            "category":"1",
            "bbox":[
                4041,
                1655,
                4570,
                2291
            ],
            "score":0.9521484375
        }
    ]
    

    2、测试图的标记数据

      因为使用的是 yolov5 ,所以测试集的标记数据和训练集一样都是 txt 格式的标注数据

    <object-class> <x_center> <y_center> <width> <height>
    

    3、将预测数据和标注数据匹配

      一个测试图中可能存在好几个目标物,那预测数据和标注数据中就会存在好几个标注框,计算两个框的交集面积,交集面积最大的两个框才是互相匹配的。

    计算两个矩形框的交集面积

    def label_area_detect(label_bbox_list, detect_bbox_list):
        x_label_min, y_label_min, x_label_max, y_label_max = label_bbox_list
        x_detect_min, y_detect_min, x_detect_max, y_detect_max = detect_bbox_list
        if (x_label_max <= x_detect_min or x_detect_max < x_label_min) or ( y_label_max <= y_detect_min or y_detect_max <= y_label_min):
            return 0
        else:
            lens = min(x_label_max, x_detect_max) - max(x_label_min, x_detect_min)
            wide = min(y_label_max, y_detect_max) - max(y_label_min, y_detect_min)
            return lens * wide

    4、多分类 precision 和 recall 的计算

       将预测数据和标注数据匹配算出一个 N*N 的矩阵

    4、代码

    import os
    import json
    import  numpy as np
    from PIL import Image
    
    # class name
    classes = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12']
    # 初始化二维0数组
    result_list = np.array(np.zeros([len(classes), len(classes)]))
    
    # 获取图片宽高
    def get_image_width_high(full_image_name):
        image = Image.open(full_image_name)
        image_width, image_high = image.size[0], image.size[1]
        return image_width, image_high
    
    
    # 读取原始标注数据
    def read_label_txt(full_label_name, full_image_name):
        fp = open(full_label_name, mode="r")
        lines = fp.readlines()
        image_width, image_high = get_image_width_high(full_image_name)
        object_list = []
        for line in lines:
            array = line.split()
            x_label_min = (float(array[1]) - float(array[3]) / 2) * image_width
            x_label_max = (float(array[1]) + float(array[3]) / 2) * image_width
            y_label_min = (float(array[2]) - float(array[4]) / 2) * image_high
            y_label_max = (float(array[2]) + float(array[4]) / 2) * image_high
            bbox = [round(x_label_min, 2), round(y_label_min, 2), round(x_label_max, 2), round(y_label_max, 2)]
            category = int(array[0])
            obj_info = {
                'category' : category,
                'bbox' : bbox
            }
            object_list.append(obj_info)
        return object_list
    
    
    # 计算交集面积
    def label_area_detect(label_bbox_list, detect_bbox_list):
        x_label_min, y_label_min, x_label_max, y_label_max = label_bbox_list
        x_detect_min, y_detect_min, x_detect_max, y_detect_max = detect_bbox_list
        if (x_label_max <= x_detect_min or x_detect_max < x_label_min) or ( y_label_max <= y_detect_min or y_detect_max <= y_label_min):
            return 0
        else:
            lens = min(x_label_max, x_detect_max) - max(x_label_min, x_detect_min)
            wide = min(y_label_max, y_detect_max) - max(y_label_min, y_detect_min)
            return lens * wide
    
    # label 匹配 detect
    def label_match_detect(image_name, label_list, detect_list):
        for label in label_list:
            area_max = 0
            area_category = 0
            label_category = label['category']
            label_bbox = label['bbox']
            for detect in detect_list:
                if detect['name'] == image_name:
                    detect_bbox = detect['bbox']
                    area = label_area_detect(label_bbox, detect_bbox)
                    if area > area_max:
                        area_max = area
                        area_category = detect['category']
    
            result_list[int(label_category)][classes.index(str(area_category))] += 1
    
    
    def main():
        image_path = '../image_data/seed/test/images/'  # 图片文件路径
        label_path = '../image_data/seed/test/labels/'  # 标注文件路径
        detect_path = 'result.json'    # 预测的数据
        precision = 0     # 精确率
        recall = 0        # 召回率
        # 读取 预测 文件数据
        with open(detect_path, 'r') as load_f:
            detect_list = json.load(load_f)
        # 读取图片文件数据
        all_image = os.listdir(image_path)
        for i in range(len(all_image)):
            full_image_path = os.path.join(image_path, all_image[i])
            # 分离文件名和文件后缀
            image_name, image_extension = os.path.splitext(all_image[i])
            # 拼接标注路径
            full_label_path = os.path.join(label_path, image_name + '.txt')
            # 读取标注数据
            label_list = read_label_txt(full_label_path, full_image_path)
            # 标注数据匹配detect
            label_match_detect(all_image[i], label_list, detect_list)
        # print(result_list)
        for i in range(len(classes)):
            row_sum, col_sum = sum(result_list[i]), sum(result_list[r][i] for r in range(len(classes)))
            precision += result_list[i][i] / float(col_sum)
            recall += result_list[i][i] / float(row_sum)
        print(f'precision: {precision / len(classes) * 100}%  recall: {recall / len(classes) * 100}%')
    
    
    if __name__ == '__main__':
        main()

      输出:

     

  • 相关阅读:
    pgsql 记录
    tomcat下放两个spring boot项目
    nigex 反向代理
    tomcat7里放springboot
    postgresql 建表语句
    从最大似然到EM算法浅解(转载)
    深度学习资料
    【vuejs小项目——vuejs2.0版本】组件化的开发方式
    【vuejs小项目——vuejs2.0版本】单页面搭建
    如何关闭eslint
  • 原文地址:https://www.cnblogs.com/yxyun/p/14544151.html
Copyright © 2020-2023  润新知