在实际的分类数据挖掘任务中,我们经常会遇到各类别样本数量很不均衡的情况。直接使用这种不均衡数据会影响一些模型(如 logistic regression、SVM 等)的分类效果。一种解决办法是对数据进行均衡采样,这里提供了一个简易的代码实现,要求输入和输出数据的格式为 Label+Tab+Features,如 Libsvm format:
-1 1:0.875 2:-1 3:-0.333333 4:-0.509434 5:-0.347032 6:-1 7:1 8:-0.236641 9:1 10:-0.935484 11:-1 12:-0.333333 13:-1
+1 1:0.166667 2:1 3:-0.333333 4:-0.433962 5:-0.383562 6:-1 7:-1 8:0.0687023 9:-1 10:-0.903226 11:-1 12:-1 13:1
+1 1:0.708333 2:1 3:1 4:-0.320755 5:-0.105023 6:-1 7:1 8:-0.419847 9:-1 10:-0.225806 12:1 13:-1
-1 1:0.583333 2:-1 3:0.333333 4:-0.603774 5:1 6:-1 7:1 8:0.358779 9:-1 10:-0.483871 12:-1 13:1
用法 Usage:
Usage: {0} [options] dataset subclass_size [output]

options:
-s method : method of selection (default 0)
     0 -- over-sampling & under-sampling given subclass_size
     1 -- over-sampling (subclass_size: any value)
     2 -- under-sampling(subclass_size: any value)
Bash example:
python SampleDataset.py -s 0 heart_scale 20 heart_scale.txt
这里s参数表示抽样的方法,
-s 0: Over sampling & Under sampling,即对样本数多于 subclass_size 的类别进行降采样,对样本数少于 subclass_size 的类别进行过采样,采样后每类样本数均为 subclass_size
-s 1: Over sampling 对类别少的进行重采样,采样后的每类样本数与最多的那一类一致
-s 2: Under sampling,对类别多的进行降采样,采样后的每类样本数与最少的那一类一致
输入数据文件heart_scale
输出数据文件heart_scale.txt
下面是代码文件:SampleDataset.py:
#!/usr/bin/env python
"""Balance the classes of a LIBSVM-format dataset by random sampling.

Every input line is expected to look like ``<label> <features...>``; only the
first whitespace-separated token (the label) is interpreted.  Run the script
without arguments to see the command-line usage.
"""
import math
import os
import random
import sys
from collections import defaultdict

import numpy as np

# Only ``check_random_state`` is needed from scikit-learn, and for an integer
# seed it is equivalent to ``numpy.random.RandomState(seed)``.  The imports are
# kept for compatibility, but the script still works when scikit-learn / SciPy
# are not installed.
try:
    from scipy.sparse import hstack, vstack
    from sklearn.datasets import dump_svmlight_file
    from sklearn.datasets import load_svmlight_file
    from sklearn.utils import check_random_state
except ImportError:
    def check_random_state(seed):
        """Return a ``numpy.random.RandomState`` seeded with *seed*."""
        return np.random.RandomState(seed)

if sys.version_info[0] >= 3:
    xrange = range


def exit_with_help(argv):
    """Print the usage text (using ``argv[0]`` as program name) and exit.

    Always terminates the process by raising ``SystemExit`` with status 1.
    """
    print("""
Usage: {0} [options] dataset subclass_size [output]

options:
-s method : method of selection (default 0)
     0 -- over-sampling & under-sampling given subclass_size
     1 -- over-sampling (subclass_size: any value)
     2 -- under-sampling(subclass_size: any value)

output : balance set file (optional)
If output is omitted, the subset will be printed on the screen.""".format(argv[0]))
    # sys.exit instead of the ``exit`` helper, which only exists when the
    # ``site`` module has been loaded.
    sys.exit(1)


def process_options(argv):
    """Parse the command line.

    Args:
        argv: full argument vector, program name first.

    Returns:
        A tuple ``(dataset, BalanceSet_size, method, BalanceSet_file)`` where
        ``dataset`` is the input path, ``BalanceSet_size`` the requested
        per-class size, ``method`` one of 0/1/2 and ``BalanceSet_file`` an
        open, writable file object (``sys.stdout`` when no output path is
        given).

    Any malformed command line prints the usage text and raises
    ``SystemExit`` instead of failing with an unhandled exception.
    """
    argc = len(argv)
    if argc < 3:
        exit_with_help(argv)

    # default method is over-sampling & under-sampling
    method = 0
    BalanceSet_file = sys.stdout

    i = 1
    while i < argc:
        if not argv[i].startswith("-"):
            break
        if argv[i] == "-s":
            i = i + 1
            if i >= argc:
                # "-s" was the last argument: its value is missing.
                exit_with_help(argv)
            try:
                method = int(argv[i])
            except ValueError:
                method = None
            if method not in (0, 1, 2):
                print("Unknown selection method {0}".format(argv[i]))
                exit_with_help(argv)
        else:
            # A mistyped flag must not be silently ignored.
            print("Unknown option {0}".format(argv[i]))
            exit_with_help(argv)
        i = i + 1

    # Both positional arguments (dataset, subclass_size) are mandatory.
    if i + 1 >= argc:
        exit_with_help(argv)

    dataset = argv[i]
    try:
        BalanceSet_size = int(argv[i + 1])
    except ValueError:
        print("subclass_size must be an integer, got {0}".format(argv[i + 1]))
        exit_with_help(argv)
    if method == 0 and BalanceSet_size < 0:
        print("subclass_size must be non-negative for method 0")
        exit_with_help(argv)

    if i + 2 < argc:
        output = argv[i + 2]
        # Opening the dataset itself for writing would truncate it before it
        # has been read.
        if os.path.realpath(output) == os.path.realpath(dataset):
            print("Output file must differ from the dataset: {0}".format(output))
            exit_with_help(argv)
        BalanceSet_file = open(output, 'w')

    return dataset, BalanceSet_size, method, BalanceSet_file


def stratified_selection(dataset, subset_size, method):
    """Choose the line numbers that make up a class-balanced sample.

    Args:
        dataset: path of the input file; the first whitespace-separated token
            of every non-blank line is taken as the class label.
        subset_size: number of lines wanted per class (only used by method 0).
        method: 0 -- every class is resized to ``subset_size``;
            1 -- every class is over-sampled to the size of the largest class;
            2 -- every class is under-sampled to the size of the smallest class.

    Returns:
        A shuffled list of 0-based line numbers into ``dataset``.  Classes
        smaller than the target keep all their lines and are topped up by
        drawing with replacement; larger classes are reduced by drawing
        without replacement.  Blank lines are never selected.  An input
        without any labelled line yields ``[]``.

    Raises:
        ValueError: if ``method`` is not 0, 1 or 2, or if ``subset_size`` is
            negative while ``method`` is 0.
    """
    if method not in (0, 1, 2):
        raise ValueError("Unknown selection method {0}".format(method))
    if method == 0 and subset_size < 0:
        raise ValueError("subset_size must be non-negative for method 0")

    # Group line numbers by label.  Line numbers refer to the raw file, so
    # blank lines are skipped without renumbering the others.
    label_linenums = defaultdict(list)
    with open(dataset) as source:
        for linenum, line in enumerate(source):
            fields = line.split(None, 1)
            if fields:
                label_linenums[fields[0]].append(linenum)

    if not label_linenums:
        return []

    class_sizes = [len(nums) for nums in label_linenums.values()]
    if method == 0:
        target = subset_size
    elif method == 1:
        target = max(class_sizes)
    else:
        target = min(class_sizes)

    # Fixed seed: the same input always produces the same selection.
    random_state = check_random_state(42)
    ret = []

    # classes with fewer data are sampled first;
    label_list = sorted(label_linenums, key=lambda x: len(label_linenums[x]))
    for label in label_list:
        linenums = label_linenums[label]
        label_size = len(linenums)
        if label_size <= target:
            # Over-sampling: keep every original line, then duplicate randomly
            # chosen ones until the class reaches the target size.
            ret += linenums
            picks = random_state.randint(low=0, high=label_size,
                                         size=target - label_size)
        else:
            # Under-sampling: draw distinct lines so no sample is repeated
            # while others of the same class are thrown away.
            picks = random_state.choice(label_size, size=target, replace=False)
        ret += [linenums[i] for i in picks]

    # Shuffle with the seeded generator so the output order is reproducible.
    random_state.shuffle(ret)
    return ret


def main(argv=sys.argv):
    """Command-line entry point: write a balanced copy of the dataset.

    Parses ``argv``, selects line numbers with ``stratified_selection`` and
    writes the corresponding dataset lines to the chosen output.  An output
    file opened by ``process_options`` is closed even when an error occurs;
    ``sys.stdout`` is left open.
    """
    dataset, subset_size, method, subset_file = process_options(argv)
    try:
        selected_lines = stratified_selection(dataset, subset_size, method)

        # select instances based on selected_lines
        with open(dataset, 'r') as source:
            datalist = source.readlines()
        # A final line without a newline would otherwise be glued to the
        # line written after it.
        if datalist and not datalist[-1].endswith('\n'):
            datalist[-1] += '\n'

        subset_file.writelines(datalist[i] for i in selected_lines)
    finally:
        if subset_file is not sys.stdout:
            subset_file.close()


if __name__ == '__main__':
    main(sys.argv)