• 计算正逆序


    1、利用 auc

    #encoding=utf8
    from itertools import groupby
    import sys
    def calc_auc_and_pnr_fast(label, pred):
        """Compute AUC and the positive/negative order ratio (PNR) of a ranking.

        Samples are ranked by ``pred`` in descending order. Every
        (positive, negative) pair scores 1 when the positive has the strictly
        higher prediction, 0.5 when both predictions are equal, and 0
        otherwise. ``cnt`` is the sum of those scores.

        Args:
            label: sequence of labels; entries equal to 1 are positives,
                entries equal to 0 are negatives, anything else is ignored.
            pred: sequence of numeric predictions, aligned with ``label``.

        Returns:
            ``(auc, pnr)`` where ``auc = cnt / (pos * neg)`` and
            ``pnr = cnt / (pos * neg - cnt)``, so ``auc == 1 / (1 + 1 / pnr)``.
            ``auc`` is NaN when there is no positive or no negative sample.
            ``pnr`` is ``inf`` when no pair is ordered wrongly and NaN when
            there are no pairs at all.
        """
        # Rank by prediction, highest first.
        ranked = sorted(zip(label, pred), key=lambda x: x[1], reverse=True)

        pos_seen = 0   # positives with a strictly higher prediction so far
        neg_total = 0
        cnt = 0.0
        # Handle each run of equal predictions as a unit so the result does
        # not depend on the input order of tied samples.
        for _, tied in groupby(ranked, key=lambda x: x[1]):
            tied_pos = 0
            tied_neg = 0
            for l, _p in tied:
                if l == 1:
                    tied_pos += 1
                elif l == 0:
                    tied_neg += 1
            # Each negative in this run is beaten by every positive ranked
            # above the run and draws (0.5) with each positive inside it.
            cnt += tied_neg * (pos_seen + 0.5 * tied_pos)
            pos_seen += tied_pos
            neg_total += tied_neg

        total = pos_seen * neg_total
        r_cnt = total - cnt  # score mass of the wrongly ordered side

        nan = float('nan')
        auc = cnt / total if total else nan
        if r_cnt:
            pnr = cnt / r_cnt
        elif cnt:
            pnr = float('inf')  # perfect ranking: nothing ordered wrongly
        else:
            pnr = nan  # no (positive, negative) pairs at all
        return auc, pnr
    
    if __name__ == '__main__':
        # Reads tab-separated lines from stdin: field 0 is the grouping key,
        # field 1 the label and field 2 the prediction. groupby only merges
        # adjacent lines, so lines sharing a key must be contiguous.
        for user, lines in groupby(sys.stdin, key=lambda x:x.split('\t')[0]):
            # Split every line once and reuse the fields for both columns.
            fields = [x.strip().split('\t') for x in lines]
            trues = [float(f[1]) for f in fields]
            preds = [float(f[2]) for f in fields]
            auc, pnr = calc_auc_and_pnr_fast(trues, preds)
            ## auc = 1/(1+1/pnr) ==> pnr = 1/ (1/a - 1)
            # The epsilon keeps the denominator non-zero when auc == 1;
            # auc == 0 is handled separately because 1 / auc would raise.
            pnr_check = 1. / (1. / auc - 1 + 1e-9) if auc else 0.
            # %-formatting gives the same space-separated output on both
            # Python 2 and Python 3 (a bare print statement is Python 2 only).
            print("%s %s %s" % (auc, pnr, pnr_check))

    2、归并排序（统计逆序对）

    https://www.jianshu.com/p/e9813ac25cb6

    """
    inversecount
    """
    from itertools import groupby
    import sys
    
    
    class InversionCounter(object):
        """
        Count correctly and wrongly ordered (true, pred) pairs.

        The counting is done with a merge-sort inversion count. All methods
        are classmethods; the class holds no state.
        """
        @classmethod
        def merge_sort_count_sub(cls, vals):
            """
            Merge-sort ``vals`` by element ``[1]`` and count inversions.

            An inversion is a pair of positions ``i < j`` whose second
            components satisfy ``vals[i][1] > vals[j][1]``.

            Args:
                vals: sequence of indexable items (e.g. ``(true, value)``
                    tuples); other iterables are materialized into a list.

            Returns:
                ``(sorted_vals, inversion_count)``.
            """
            if not isinstance(vals, (list, tuple)):
                vals = list(vals)

            n = len(vals)
            if n <= 1:
                return vals, 0

            # Floor division: a float index is rejected by slicing on Python 3.
            half = n // 2
            left_vals, left_cnt = cls.merge_sort_count_sub(vals[:half])
            right_vals, right_cnt = cls.merge_sort_count_sub(vals[half:])

            left_i = 0
            right_i = 0

            mid_cnt = 0
            new_vals = []
            while left_i < len(left_vals) and right_i < len(right_vals):
                if left_vals[left_i][1] <= right_vals[right_i][1]:
                    new_vals.append(left_vals[left_i])
                    left_i += 1
                else:
                    # Every remaining left item is larger than this right
                    # item, so each of them forms one inversion with it.
                    mid_cnt += (len(left_vals) - left_i)
                    new_vals.append(right_vals[right_i])
                    right_i += 1

            # At most one of the two halves still has items left.
            new_vals.extend(left_vals[left_i:])
            new_vals.extend(right_vals[right_i:])

            return new_vals, left_cnt + mid_cnt + right_cnt


        @classmethod
        def merge_sort_count_strict_right(cls, trues, preds):
            """
            Count pairs whose trues and preds are strictly ordered the same way.

            Items are sorted by ``(true, -pred)``, so items sharing a true
            value never form an inversion among themselves; the inversions of
            ``-pred`` are then the pairs with ``true_i < true_j`` and
            ``pred_i < pred_j``.
            """
            # sorted() returns the ordered list (and accepts the zip iterator
            # of Python 3), unlike list.sort() on the result of zip().
            vals = sorted(zip(trues, (-p for p in preds)))
            return cls.merge_sort_count_sub(vals)[1]


        @classmethod
        def merge_sort_count_strict_wrong(cls, trues, preds):
            """
            Count pairs whose trues and preds are strictly ordered oppositely.

            Items are sorted by ``(true, pred)``; the inversions of ``pred``
            are then the pairs with ``true_i < true_j`` and ``pred_i > pred_j``.
            """
            vals = sorted(zip(trues, preds))
            return cls.merge_sort_count_sub(vals)[1]


        @classmethod
        def merge_sort_count_right(cls, trues, preds):
            """
            Count pairs with different trues that are not strictly wrong.

            Pairs with tied preds are included in this count.
            """
            return cls.merge_sort_count_pair(trues) - cls.merge_sort_count_strict_wrong(trues, preds)


        @classmethod
        def merge_sort_count_wrong(cls, trues, preds):
            """
            Count pairs with different trues that are not strictly right.

            Pairs with tied preds are included in this count.
            """
            return cls.merge_sort_count_pair(trues) - cls.merge_sort_count_strict_right(trues, preds)


        @classmethod
        def merge_sort_count_pair(cls, trues, preds=None):
            """
            Count the pairs of items whose true values differ.

            preds: dummy variable, not used inside the function
            """
            trues = sorted(trues)
            acc_num = 0
            pair = 0
            for _, ks in groupby(trues):
                current_num = sum(1 for _ in ks)
                acc_num += current_num
                # Pair each item of this run with every item of a later
                # (larger) run.
                pair += (len(trues) - acc_num) * current_num

            return pair
    
    if __name__ == '__main__':
        # Reads tab-separated lines from stdin: field 0 is the grouping key,
        # field 1 the true value and field 2 the prediction. groupby only
        # merges adjacent lines, so lines sharing a key must be contiguous.
        right = 0.
        wrong = 0.
        for user, lines in groupby(sys.stdin, key=lambda x:x.split('\t')[0]):
            # Split every line once and reuse the fields for both columns.
            fields = [x.strip().split('\t') for x in lines]
            trues = [float(f[1]) for f in fields]
            preds = [float(f[2]) for f in fields]
            # Pair counts are accumulated over all groups before the ratio
            # is taken.
            right += InversionCounter.merge_sort_count_strict_right(trues, preds)
            wrong += InversionCounter.merge_sort_count_strict_wrong(trues, preds)
        # Guard the ratio: with no wrongly ordered pair the division would
        # raise ZeroDivisionError.
        if wrong:
            ratio = right / wrong
        elif right:
            ratio = float('inf')
        else:
            ratio = float('nan')
        print (right, wrong, ratio)
  • 相关阅读:
    mysql 中将汉字(中文)按照拼音首字母排序
    数据库连接客户端 dbeaver 程序包以及使用说明
    maven 项目在 tomcat 中启动报错:Caused by: java.util.zip.ZipException: invalid LOC header (bad signature)
    iPadOS 更新日志
    iOS 更新日志
    mybatis 中 if else 用法
    Chrome 地址栏如何设置显示 http/https 和 www
    Windows 常用工具 & 开发工具 & Chrome插件 & Firefox 插件 & 办公软件
    elasticsearch安装ik分词器
    js关闭浏览器
  • 原文地址:https://www.cnblogs.com/zle1992/p/16263169.html
Copyright © 2020-2023  润新知