• PSENet 的 eval_ctw1500.py 解析


    评价模型指标:自带的 eval_ctw1500.py 对每个预测框与每个真值框计算交并比(IoU,即交集面积 / 并集面积)。若 IoU 不小于 0.5 且该真值框尚未被匹配过,则该预测记为正例 TP,否则记为误检 FP。其中用到了 Polygon 模块(import Polygon as plg),可以方便地计算多边形的交集及其面积。
    1.pip install Polygon2
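    2. cover = set():这里用到了 Python 的 set。它是元素不重复的无序集合,在脚本里用来记录已经被匹配过的真值框编号,避免同一个真值框被重复计为 TP。C++ 里也有 std::set,功能类似,但 std::set 是有序的;与 Python set 更接近的是 std::unordered_set。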

    file_util.py

    import os
    
    def read_dir(root):
        """Return every file path under *root*, recursively, sorted.

        Paths are built with ``os.path.join`` and then normalised to use
        '/' as the separator, so callers can split on '/' on any platform.

        Args:
            root: Directory to walk.

        Returns:
            A sorted list of file paths (strings) with '/' separators.
        """
        file_path_list = []
        for dir_path, _dirs, file_names in os.walk(root):
            for file_name in file_names:
                # Normalise Windows backslashes so downstream '/'-splitting works.
                full_path = os.path.join(dir_path, file_name)
                file_path_list.append(full_path.replace('\\', '/'))
        file_path_list.sort()
        return file_path_list
    
    def read_file(file_path, encoding=None):
        """Return the full text content of *file_path*.

        Args:
            file_path: Path of the file to read.
            encoding: Optional text encoding passed to ``open``; ``None``
                keeps the platform default (the original behaviour).

        Returns:
            The file content as a single string.
        """
        # ``with`` guarantees the handle is closed even if read() raises.
        with open(file_path, 'r', encoding=encoding) as file_object:
            return file_object.read()
    
    def write_file(file_path, file_content, encoding=None):
        """Write *file_content* to *file_path*, overwriting any existing file.

        The parent directory (everything before the last '/') is created
        first if it does not exist.

        Args:
            file_path: Destination path, using '/' as the separator.
            file_content: Text to write.
            encoding: Optional text encoding passed to ``open``; ``None``
                keeps the platform default.
        """
        father_dir = '/'.join(file_path.split('/')[:-1])
        # An empty parent (bare filename, or a file directly under '/') means
        # there is nothing to create; os.makedirs('') would raise.
        if father_dir and not os.path.exists(father_dir):
            os.makedirs(father_dir)
        with open(file_path, 'w', encoding=encoding) as file_object:
            file_object.write(file_content)
    
    
    def write_file_not_cover(file_path, file_content, encoding=None):
        """Append *file_content* to *file_path* instead of overwriting it.

        The parent directory (everything before the last '/') is created
        first if it does not exist.

        Args:
            file_path: Destination path, using '/' as the separator.
            file_content: Text to append.
            encoding: Optional text encoding passed to ``open``; ``None``
                keeps the platform default.
        """
        father_dir = '/'.join(file_path.split('/')[:-1])
        # A bare filename has an empty parent; os.makedirs('') would raise
        # FileNotFoundError, so only create a directory when there is one.
        if father_dir and not os.path.exists(father_dir):
            os.makedirs(father_dir)
        with open(file_path, 'a', encoding=encoding) as file_object:
            file_object.write(file_content)
    

    eval_ctw1500.py

    import file_util
    import Polygon as plg
    import numpy as np
    
    # Directory scanned (recursively) for prediction files; each file is paired
    # with the ground-truth file of the same basename under gt_root.
    pred_root = "../../outputs/submit_ctw1500/"
    # Directory of ground-truth label files (trailing '/' is required because
    # paths are built by plain string concatenation).
    gt_root = "../../data/CTW1500/test/text_label_curve/"
    
    def get_pred(path):
        """Parse a prediction file into per-polygon integer coordinate lists.

        Each non-blank line of *path* is a comma-separated list of integers
        ``x1,y1,x2,y2,...`` describing one polygon.

        Args:
            path: Path to the prediction text file.

        Returns:
            A list with one ``list[int]`` per non-blank line.
        """
        lines = file_util.read_file(path).split('\n')
        bboxes = []
        for line in lines:
            # Skip blank lines, including a lone '\r' left by CRLF files,
            # which int() would otherwise reject.
            if not line.strip():
                continue
            bbox = line.split(',')
            if len(bbox) % 2 == 1:
                # An odd number of values cannot be paired into (x, y)
                # points; report the offending file.
                print(path)
            bboxes.append([int(x) for x in bbox])
        return bboxes
    
    def get_gt(path):
        """Parse a ground-truth label file into absolute polygon coordinates.

        Each non-blank line is comma-separated. Fields 0 and 1 are a
        reference point (x1, y1). Fields 4..31 (28 values) are added to
        ``[x1, y1]`` repeated 14 times, i.e. they are 14 (x, y) offsets
        relative to that point. Fields 2 and 3 are not used.

        Args:
            path: Path to the ground-truth text file.

        Returns:
            A list of 1-D integer numpy arrays, each holding 28 absolute
            coordinates laid out as ``x, y, x, y, ...``.

        Raises:
            IndexError: if a line has fewer than 32 fields.
            ValueError: if a field is not an integer.
        """
        print("###############################path=%s"%(path))
        lines = file_util.read_file(path).split('\n')
        bboxes = []
        for line in lines:
            # Drop a UTF-8 byte-order mark (present when the file is decoded
            # as UTF-8), which int() cannot parse.
            line = line.lstrip('\ufeff')
            if not line.strip():
                continue
            gt = line.split(',')

            # Builtin int replaces np.int, an alias removed in NumPy 1.24.
            x1 = int(gt[0])
            y1 = int(gt[1])

            offsets = [int(gt[i]) for i in range(4, 32)]
            bbox = np.asarray(offsets) + ([x1, y1] * 14)

            bboxes.append(bbox)
        return bboxes
    
    def get_union(pD, pG):
        """Return the area of the union of polygons *pD* and *pG*.

        Uses inclusion-exclusion: area(A) + area(B) - area(A & B).
        """
        area_d = pD.area()
        area_g = pG.area()
        return area_d + area_g - get_intersection(pD, pG)
    
    def get_intersection(pD, pG):
        """Return the area shared by polygons *pD* and *pG*.

        The intersection is computed with ``pD & pG``. When that result is
        empty (``len() == 0``) the integer 0 is returned; otherwise its
        ``area()`` is returned.
        """
        overlap = pD & pG
        return overlap.area() if len(overlap) else 0
    
    if __name__ == '__main__':
        # IoU threshold at or above which a prediction/GT pair counts as a match.
        th = 0.5
        pred_list = file_util.read_dir(pred_root)

        tp, fp, npos = 0, 0, 0

        for pred_path in pred_list:
            preds = get_pred(pred_path)
            # The GT file shares the prediction file's basename.
            gt_path = gt_root + pred_path.split('/')[-1]
            gts = get_gt(gt_path)
            npos += len(gts)

            # GT polygons do not depend on the prediction, so build them once
            # per file instead of once per prediction.
            gt_polys = [plg.Polygon(np.array(gt).reshape(-1, 2)) for gt in gts]

            cover = set()  # indices of GT polygons already matched in this file
            for pred in preds:
                pred = np.array(pred)
                # '//' keeps the row count an int; '/' yields a float in
                # Python 3, which reshape rejects.
                pred_p = plg.Polygon(pred.reshape(pred.shape[0] // 2, 2))
                # if pred.shape[0] <= 2:
                #     continue

                flag = False
                for gt_id, gt_p in enumerate(gt_polys):
                    union = get_union(pred_p, gt_p)
                    if union == 0:
                        # Both polygons are degenerate; avoid ZeroDivisionError.
                        continue
                    inter = get_intersection(pred_p, gt_p)

                    # NOTE(review): one prediction may claim several unmatched
                    # GTs here; kept unchanged so scores stay comparable with
                    # the original script.
                    if inter * 1.0 / union >= th and gt_id not in cover:
                        flag = True
                        cover.add(gt_id)
                if flag:
                    tp += 1.0
                else:
                    fp += 1.0

        print(tp, fp, npos)
        # Guard against empty prediction or GT sets.
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / npos if npos > 0 else 0.0
        hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall)

        print('p: %.4f, r: %.4f, f: %.4f'%(precision, recall, hmean))
    
    
  • 相关阅读:
    js 构造函数 constructor
    js foreach和map区别
    js 静态方法和实例方法
    学习知识点总结(es6篇)
    java1.5新特性(转)
    21 Managing the Activity Lifecycle
    Java进阶Collection集合框架概要·16
    Java进阶核心之集合框架Map下集·18
    Java进阶核心之集合框架Set·19
    Java进阶核心之集合框架List·17
  • 原文地址:https://www.cnblogs.com/yanghailin/p/11139080.html
Copyright © 2020-2023  润新知