数据来源《统计学习方法》(李航):
x=[ 0,1,2,3,4,5,6,7,8,9]
y=[1,1,1, -1,-1,-1, 1,1,1,-1];
实现:
1 # -*- coding: utf-8 -*- 2 """ 3 Created on Sun Oct 14 18:52:18 2018 4 @author: Administrator 5 """ 6 #处理并获取数据 7 import pandas as pd 8 import math 9 import numpy as np 10 11 def getData(): 12 x=[ 0,1,2,3,4,5,6,7,8,9] 13 y=[1,1,1, -1,-1,-1, 1,1,1,-1]; 14 z=[] 15 for i in range(0,len(x)): 16 z.append( [ x[i],y[i] ] ); 17 data=pd.DataFrame(z,columns=['x','y']); 18 return data 19 20 #基分类器 21 def basicClassifier(Dweight,data): 22 #采取x<v 23 #阈值选择使分类误差最小的 24 threshold=0; #阈值 25 error=1; #分类误差率 26 label=1; # x<v 为正(负)类, 27 for i in range( 0,len(data) ): 28 tmpThreshold=data['x'][i]; #选择阈值 29 tmpError1=0 #判断 label 为正类的误差 30 tmpError2=0 #判断 label 为负类的误差 31 for j in range( 0,len(data) ): 32 if( data['x'][j]<tmpThreshold ): 33 if( data['y'][j]==-1 ): 34 tmpError1+=Dweight[j];# 如果label为正类 35 else: 36 tmpError2+=Dweight[j];# 如果label为负类 37 else: 38 if( data['y'][j]==1 ): 39 tmpError1+=Dweight[j];# label为正类 40 else: 41 tmpError2+=Dweight[j];# label为负类 42 if( error>tmpError1 and tmpError1!=0 and tmpError1<tmpError2): 43 threshold=tmpThreshold; 44 error=tmpError1; 45 label=1 46 if( error>tmpError2 and tmpError1>tmpError2 and tmpError2!=0 ): 47 threshold=tmpThreshold; 48 label=-1; 49 error=tmpError2; 50 #求该基本分类器的权重 51 alpha = math.log((1-error)/error)/2 ; 52 #更新数据权重 53 NewDweight=[] 54 sumAll=0 #规范化因子 55 for i in range(0,len(data)): 56 if( data['x'][i]< threshold ): 57 Gm=label 58 else: 59 Gm=label*(-1) 60 vv=Dweight[i]*( math.e**(alpha*data['y'][i]*Gm*(-1)) ) 61 NewDweight.append(vv); 62 sumAll+=vv; 63 NewDweight = np.array(NewDweight); 64 NewDweight = NewDweight/sumAll; 65 return threshold,label,alpha,list(NewDweight); 66 67 #最后的分类器 68 def adaboost(data): 69 threshold=[]; #每个基本分类器的阈值 70 label=[]; #每个基本分类器 x<v 类别 71 Cweight=[]; #每个基本分类器的权重 72 Dweight=[]; #每项数据权重 73 for i in range(0,len(data)): 74 Dweight.append(1/len(data)); 75 M=3; # M个基本分类器组合 76 for i in range(0,M): 77 th, la, Cw ,Dweight= basicClassifier(Dweight,data) 78 threshold.append(th) 79 label.append(la) 80 Cweight.append(Cw) 81 return [ threshold, label,Cweight ] 82 83 def classify(data, model): 84 for i in range(0,len(data)): 85 val=0; 86 for j in range(0,len(model[0])): 87 if( data['x'][i]< model[0][j] ): 88 val+=( model[1][j]*model[2][j] ); 89 else: 90 val+=( model[1][j]*model[2][j]*(-1) ); 91 if(val<0): 92 print('-1') 93 else: 94 print('1'); 95 data = getData(); 96 mo=adaboost(data); 97 classify(data,mo)