1、利用 auc
#encoding=utf8
"""Per-user AUC and pnr computed from tab-separated (user, label, score) lines."""
from itertools import groupby
import sys


def calc_auc_and_pnr_fast(label, pred):
    """Compute AUC and pnr for one group of binary-labelled predictions.

    Every (positive, negative) pair is scored 1 when the positive has the
    higher prediction, 0.5 when both predictions are equal and 0 otherwise.
    ``auc`` is that score divided by the number of pairs; ``pnr`` is the score
    divided by the remainder (pairs minus score), so ``auc = 1 / (1 + 1 / pnr)``.

    Args:
        label: sequence of labels; 1 marks a positive and 0 a negative.
            Samples carrying any other label are ignored.
        pred: sequence of predicted scores, aligned with ``label``.

    Returns:
        Tuple ``(auc, pnr)`` of floats. Both are NaN when there is no
        (positive, negative) pair at all; ``pnr`` is ``inf`` when no pair is
        ordered wrongly.
    """
    # Highest prediction first, so every positive already seen outranks the
    # negatives of the group currently being processed.
    ranked = sorted(zip(label, pred), key=lambda s: s[1], reverse=True)

    pos_seen = 0
    neg_seen = 0
    concordant = 0.0
    # Handle equal predictions as one group: a tie is worth exactly half a
    # pair no matter in which order the tied samples arrived.
    for _, tied in groupby(ranked, key=lambda s: s[1]):
        tied_pos = 0
        tied_neg = 0
        for sample_label, _score in tied:
            if sample_label == 1:
                tied_pos += 1
            elif sample_label == 0:
                tied_neg += 1
        concordant += tied_neg * (pos_seen + 0.5 * tied_pos)
        pos_seen += tied_pos
        neg_seen += tied_neg

    total_pairs = pos_seen * neg_seen
    if total_pairs == 0:
        # Only one class present (or empty input): both metrics are undefined.
        return float('nan'), float('nan')

    discordant = total_pairs - concordant
    auc = concordant / total_pairs
    pnr = concordant / discordant if discordant else float('inf')
    return auc, pnr


if __name__ == '__main__':
    # stdin lines look like "<user>\t<label>\t<score>". groupby only merges
    # adjacent lines, so all lines of one user must be contiguous.
    for user, lines in groupby(sys.stdin, key=lambda x: x.split('\t')[0]):
        rows = [x.strip().split('\t') for x in lines]
        trues = [float(r[1]) for r in rows]
        preds = [float(r[2]) for r in rows]
        auc, pnr = calc_auc_and_pnr_fast(trues, preds)
        # auc = 1 / (1 + 1 / pnr)  ==>  pnr = 1 / (1 / auc - 1)
        # The epsilon keeps auc == 1 finite; auc == 0 maps straight to 0.
        pnr_check = 1. / (1. / auc - 1 + 1e-9) if auc else 0.0
        # %-formatting prints the same text on Python 2 and Python 3.
        print('%s %s %s' % (auc, pnr, pnr_check))
2、归并
https://www.jianshu.com/p/e9813ac25cb6
""" inversecount """ from itertools import groupby import sys class InversionCounter(object): """ InversionCounter """ @classmethod def merge_sort_count_sub(cls, vals): """ merge_sort_count_sub """ if sys.version > '3': if len(list(vals)) <= 1: return vals, 0 else: if len(vals) <= 1: return vals, 0 n = len(vals) left_vals, left_cnt = cls.merge_sort_count_sub(vals[:n / 2]) right_vals, right_cnt = cls.merge_sort_count_sub(vals[n / 2:]) left_i = 0 right_i = 0 mid_cnt = 0 new_vals = [] while True: if left_vals[left_i][1] <= right_vals[right_i][1]: new_vals.append(left_vals[left_i]) left_i += 1 elif left_vals[left_i][1] > right_vals[right_i][1]: mid_cnt += (len(left_vals) - left_i) new_vals.append(right_vals[right_i]) right_i += 1 if left_i == len(left_vals): new_vals.extend(right_vals[right_i:]) break if right_i == len(right_vals): new_vals.extend(left_vals[left_i:]) break return new_vals, left_cnt + mid_cnt + right_cnt @classmethod def merge_sort_count_strict_right(cls, trues, preds): """ merge_sort_count_strict_right """ neg_preds = (-p for p in preds) vals = zip(trues, neg_preds) if sys.version > '3': sorted(vals) else : vals.sort() return cls.merge_sort_count_sub(vals)[1] @classmethod def merge_sort_count_strict_wrong(cls, trues, preds): """ merge_sort_count_strict_wrong """ vals = zip(trues, preds) if sys.version > '3': sorted(vals) else : vals.sort() return cls.merge_sort_count_sub(vals)[1] @classmethod def merge_sort_count_right(cls, trues, preds): """ merge_sort_count_right """ return cls.merge_sort_count_pair(trues) - cls.merge_sort_count_strict_wrong(trues, preds) @classmethod def merge_sort_count_wrong(cls, trues, preds): """ merge_sort_count_wrong """ return cls.merge_sort_count_pair(trues) - cls.merge_sort_count_strict_right(trues, preds) @classmethod def merge_sort_count_pair(cls, trues, preds=None): """ preds: dummpy variable, no need inside function """ trues = sorted(trues) acc_num = 0 pair = 0 for k, ks in groupby(trues): current_num = sum(1 for _ in 
ks) acc_num += current_num pair += (len(trues) - acc_num) * current_num return pair if __name__ == '__main__': right = 0. wrong = 0. for user, lines in groupby(sys.stdin, key=lambda x:x.split('\t')[0]): lines = list(lines) #print lines trues = [float(x.strip().split('\t')[1]) for x in lines] preds = [float(x.strip().split('\t')[2]) for x in lines] right += InversionCounter.merge_sort_count_strict_right(trues, preds) wrong += InversionCounter.merge_sort_count_strict_wrong(trues, preds) print (right, wrong, right / wrong)