▶ 使用逻辑斯谛(logistic)回归模型来进行多分类:采用 one-vs-rest(一对其余)的方式训练 k 个二分类器(k 为类别数),然后选择各分类器中输出概率最高的类别作为最终结果
● 代码:向下兼容二分类,但由于需要训练 k 个分类器,计算量变大了
"""One-vs-rest logistic regression for multi-class classification.

One binary logistic classifier is trained per class (samples of that class are
the positives, all other samples the negatives).  A sample is assigned to the
class whose classifier produces the highest score.  With two classes this
degenerates to ordinary binary classification, at the cost of training two
classifiers instead of one.

Running the module as a script generates synthetic data for several
(dimension, class count) combinations, trains, prints the error ratios and,
for up to three dimensions, saves a scatter plot of the test set.
"""
import numpy as np

dataSize = 10000    # number of samples createData generates by default
trainRatio = 0.3    # leading fraction of the samples used for training
ita = 0.05          # learning rate of each stochastic gradient step
epsilon = 0.01      # early-stop threshold, see gradientDescent
defaultTurn = 200   # default maximum number of passes over the training set
trans = 0.5         # not referenced anywhere in this script; kept for compatibility


def myColor(x):
    """Colour function used to tint the scatter points.

    Each channel is a piecewise-linear function of ``x`` on [0, 1] that is
    squared before being returned; values above 1 map to black.

    Args:
        x: a scalar or a numpy array.

    Returns:
        ``[r, g, b]``.  For scalar ``x`` the entries are plain floats, for
        array input they are arrays shaped like ``x``.
    """
    r = np.select([x < 1/2, x < 3/4, x <= 1, True], [0, 4 * x - 2, 1, 0])
    g = np.select([x < 1/4, x < 3/4, x <= 1, True], [4 * x, 1, 4 - 4 * x, 0])
    b = np.select([x < 1/4, x < 1/2, x <= 1, True], [1, 2 - 4 * x, 0, 0])
    rgb = [r**2, g**2, b**2]
    if np.ndim(x) == 0:
        # np.select yields 0-d arrays for scalar input; hand back plain floats.
        return [float(channel) for channel in rgb]
    return rgb


def sigmoid(x):
    """Logistic function ``1 / (1 + exp(-x))``."""
    return 1.0 / (1 + np.exp(-x))


def function(x, para):
    """Regression function: ``exp(-(w_i . x + b_i))`` for every class ``i``.

    Args:
        x: one sample (1-D array of features).
        para: ``(weights, bias)`` as returned by gradientDescent.

    Returns:
        1-D array with one value per class.  A smaller value corresponds to a
        larger linear score, i.e. a higher probability for that class.
    """
    weights, bias = para[0], para[1]
    vector = np.array([np.exp(-np.sum(x * weights[i]) - bias[i])
                       for i in range(len(weights))])
    return vector  # return vector / np.sum(vector)


def judge(x, para):
    """Classification function: index of the class with the highest score.

    The linear scores are compared directly instead of taking the arg-min of
    ``exp(-score)``: the result is the same while the exponentials are finite
    and distinct, but it stays correct when they would underflow to 0.0 or
    overflow to inf.

    Args:
        x: one sample (1-D) or a batch of samples (2-D, one sample per row).
        para: ``(weights, bias)`` as returned by gradientDescent.

    Returns:
        A single class index for 1-D input, an array of indices for 2-D input.
    """
    weights, bias = para[0], para[1]
    scores = np.asarray(x) @ np.asarray(weights).T + bias
    return np.argmax(scores, axis=-1)


def dataSplit(x, y, part):
    """Split x and y at index ``part``: (x head, y head, x tail, y tail)."""
    return x[:part], y[:part], x[part:], y[part:]


def createData(dim, kind, count=dataSize):
    """Create a reproducible synthetic data set and print the class shares.

    Features are uniform in [0, 1).  For ``kind == 2`` the label is 1 when
    ``(3 - 2*dim) * x0 + 2 * sum(x1..) > 0.5``.  Otherwise the features are
    projected on a random weight vector that sums to 1 and the projection
    (which lies in [0, 1)) is cut into ``kind`` equal-width bands.

    Note: the global numpy random state is re-seeded with 103 on every call so
    that the generated data is identical from run to run.

    Returns:
        ``(X, Y)`` with ``X`` of shape (count, dim) and integer labels ``Y``.
    """
    np.random.seed(103)
    X = np.random.rand(count, dim)
    if kind == 2:
        Y = ((3 - 2 * dim) * X[:, 0] + 2 * np.sum(X[:, 1:], 1) > 0.5).astype(int)
    else:
        randomVector = np.random.rand(dim)
        randomVector /= np.sum(randomVector)
        Y = (np.sum(X * randomVector, 1) * kind).astype(int)
    print("dim = %d, kind = %d, dataSize = %d"%(dim, kind, count))
    kindCount = np.bincount(Y, minlength=kind)  # share of every class
    for i in range(kind):
        print("kind %d -> %4f"%(i, kindCount[i]/count))
    return X, Y


def gradientDescent(dataX, dataY, turn=defaultTurn, kind=None, verbose=True):
    """Train the one-vs-rest classifiers with stochastic gradient steps.

    Every sample updates all classifiers at once: for classifier ``j`` the
    target is 1 if the sample belongs to class ``j`` and 0 otherwise, and
    ``error = target - sigmoid(w_j . [x, 1])``.  An epoch ends early-stopping
    the training when fewer than ``count * epsilon`` of its (sample, class)
    errors exceeded 0.5 in absolute value.

    Args:
        dataX: training features, shape (count, dim).
        dataY: class labels in ``0 .. kind-1``.
        turn: maximum number of epochs.
        kind: number of classes.  Defaults to ``max(dataY) + 1`` so that a
            label missing from the training data does not shrink the model.
        verbose: print the weights after every epoch and the final error
            ratio on the training data.

    Returns:
        ``(weights, bias)`` with shapes (kind, dim) and (kind,).

    Raises:
        ValueError: if there are no training samples.
    """
    dataX = np.asarray(dataX, dtype=float)
    dataY = np.asarray(dataY)
    count, dim = np.shape(dataX)
    if count == 0:
        raise ValueError("gradientDescent needs at least one training sample")
    if kind is None:
        kind = int(np.max(dataY)) + 1

    # Append a constant 1 so the bias is learned as the last weight column.
    xE = np.concatenate((dataX, np.ones((count, 1))), axis=1)
    # target[i, j] is 1 when sample i belongs to class j, else 0.
    target = (dataY[:, np.newaxis] == np.arange(kind)).astype(float)
    w = np.ones([kind, dim + 1])

    for t in range(turn):
        errorCount = 0
        for i in range(count):
            # error = yReal - yPredict, for all classifiers at once
            error = target[i] - sigmoid(np.sum(w * xE[i], axis=1))
            w += (ita * error)[:, np.newaxis] * xE[i]
            errorCount += int(np.count_nonzero(np.abs(error) > 0.5))
        if verbose:
            print(w)
        if errorCount < count * epsilon:
            break

    para = (w[:, :-1], w[:, -1])
    if verbose:
        errorRatioOnTrainData = np.mean(judge(dataX, para) != dataY)
        print("errorRatioOnTrainData = %4f "%(errorRatioOnTrainData))
    return para


def _plotResult(dim, kind, testX, testY, myResult, savePath):
    """Save a scatter plot of the test set (dim 1, 2 or 3) to ``savePath``.

    Correctly classified points are coloured by class, misclassified points
    are drawn larger in the colour ``myColor(1)``.
    """
    # Imported here so that training and evaluation work without matplotlib.
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle
    # Kept from the original script; importing Axes3D is what makes the "3d"
    # projection available on old matplotlib releases.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
    from mpl_toolkits.mplot3d.art3d import Poly3DCollection  # noqa: F401

    correct = myResult == testY
    if dim == 1:
        # 1-D data gets the true class as second coordinate, scaled to [0, 1].
        # Error points use the same scaling so they stay on their class row.
        levels = testY / max(kind - 1, 1)
        coords = np.column_stack((testX[:, 0], levels))
    else:
        coords = testX

    fig = plt.figure(figsize=(10, 8))
    if dim == 3:
        # Axes3D(fig) no longer attaches itself to the figure on current
        # matplotlib; add_axes with the same full-figure rectangle does.
        ax = fig.add_axes([0, 0, 1, 1], projection="3d")
        ax.set_xlim3d(-0.1, 1.1)
        ax.set_ylim3d(-0.1, 1.1)
        ax.set_zlim3d(-0.1, 1.1)
        ax.set_xlabel('X', fontdict={'size': 15, 'color': 'k'})
        ax.set_ylabel('Y', fontdict={'size': 15, 'color': 'k'})
        ax.set_zlabel('Z', fontdict={'size': 15, 'color': 'k'})
        # Disabled overlay from the original script (a polygon in the cube):
        #v = [(0, 0, 0.25), (0, 0.25, 0), (0.5, 1, 0), (1, 1, 0.75), (1, 0.75, 1), (0.5, 0, 1)]
        #f = [[0,1,2,3,4,5]]
        #poly3d = [[v[i] for i in j] for j in f]
        #ax.add_collection3d(Poly3DCollection(poly3d, edgecolor = 'k', facecolors = [0.5,0.25,0.0,0.5], linewidths=1))
        legendLoc = [0.85, 0.02]
    else:
        ax = fig.add_subplot(111)
        ax.set_xlim(-0.1, 1.1)
        ax.set_ylim(-0.1, 1.1)
        legendLoc = [0.84, 0.012]

    pointSize = 2 if dim == 1 else 8
    labelSuffix = "Data" if dim == 1 else ""
    for i in range(kind):
        # Boolean masks keep the selection 2-D even when it is empty.
        classPoints = coords[correct & (testY == i)]
        ax.scatter(*classPoints.T, color=myColor(i / kind), s=pointSize,
                   label="class" + str(i) + labelSuffix)
    errorPoints = coords[~correct]
    if len(errorPoints) != 0:
        ax.scatter(*errorPoints.T, color=myColor(1), s=16, label="errorData")

    # Zero-sized rectangles serve as colour swatches for the legend.
    R = [Rectangle((0, 0), 0, 0, color=myColor(i / kind)) for i in range(kind)]
    R.append(Rectangle((0, 0), 0, 0, color=myColor(1)))
    ax.legend(R, ["class" + str(i) for i in range(kind)] + ["errorData"],
              loc=legendLoc, ncol=1, numpoints=1, framealpha=1)

    fig.savefig(savePath)
    plt.close(fig)


def test(dim, kind, savePath=None):
    """Run one experiment: create data, train, evaluate and (dim < 4) plot.

    Args:
        dim: number of features.
        kind: number of classes.
        savePath: where to save the plot.  Defaults to the original location
            ``R:\\dim<dim>kind<kind>.png``.

    Returns:
        The error ratio on the test part of the data.
    """
    allX, allY = createData(dim, kind)
    trainX, trainY, testX, testY = dataSplit(allX, allY, int(dataSize * trainRatio))

    para = gradientDescent(trainX, trainY, kind=kind)  # training

    myResult = judge(testX, para)
    errorRatio = float(np.mean(myResult != testY))
    print("dim = %d, errorRatio = %4f "%(dim, errorRatio))

    if dim >= 4:  # no plot for 4 or more dimensions, only the error ratio
        return errorRatio

    if savePath is None:
        savePath = "R:\\dim" + str(dim) + "kind" + str(kind) + ".png"
    _plotResult(dim, kind, testX, testY, myResult, savePath)
    return errorRatio


# `test` is an experiment driver, not a unit test: keep pytest/nose from
# collecting it when this module is star-imported into a test module.
test.__test__ = False


if __name__ == '__main__':
    for caseDim, caseKind in [(1, 2), (2, 2), (3, 2), (4, 2), (5, 2),
                              (1, 3), (2, 3), (2, 4), (3, 3), (3, 4),
                              (4, 4), (5, 6)]:
        test(caseDim, caseKind)
● 输出结果
dim = 1, kind = 2, dataSize = 10000 kind 0 -> 0.491000 kind 1 -> 0.509000 [[-6.71486872 3.25224943] [ 6.90100937 -3.34871095]] [[-9.4024679 4.63658391] [ 9.51290414 -4.69311656]] [[-11.18847685 5.54813528] [ 11.26986341 -5.58954968]] [[-12.5673646 6.2486029 ] [ 12.63308806 -6.2819321 ]] [[-13.70905382 6.82693293] [ 13.76484542 -6.85516245]] [[-14.69334986 7.32458602] [ 14.7422152 -7.34927213]] [[-15.56456028 7.76445884] [ 15.60828366 -7.78652142]] [[-16.35002989 8.1606272 ] [ 16.38976341 -8.18065822]] [[-17.06791905 8.52241138] [ 17.10445313 -8.54081596]] errorRatioOnTrainData = 0.000000 dim = 1, errorRatio = 0.000857 dim = 2, kind = 2, dataSize = 10000 kind 0 -> 0.504000 kind 1 -> 0.496000 [[ 3.13914102 -6.57280315 1.55975592] [-3.0336306 6.74084399 -1.70421297]] [[ 4.299232 -9.1666085 2.28471823] [-4.2828194 9.25105756 -2.33797473]] ... [[ 9.06732636 -18.39799714 4.5422244 ] [ -9.07689803 18.41643909 -4.54667908]] [[ 9.34598742 -18.9326538 4.67067573] [ -9.35504382 18.95010089 -4.67488586]] errorRatioOnTrainData = 0.008333 dim = 2, errorRatio = 0.006286 dim = 3, kind = 2, dataSize = 10000 kind 0 -> 0.501800 kind 1 -> 0.498200 [[ 5.68320585 -3.41622995 -3.29229634 0.291568 ] [-5.56080847 3.5668759 3.4417622 -0.53161699]] [[ 7.73185326 -4.86686995 -4.78136275 0.72516127] [-7.70673669 4.93636661 4.84909001 -0.8170064 ]] ... [[ 20.80913419 -13.92078602 -14.0109533 3.30945084] [-20.81629159 13.92562185 14.01584334 -3.31075675]] [[ 21.06744447 -14.09526106 -14.18737273 3.35653806] [-21.0744349 14.09998155 14.19214546 -3.35781137]] errorRatioOnTrainData = 0.011333 dim = 3, errorRatio = 0.011429 dim = 4, kind = 2, dataSize = 10000 kind 0 -> 0.503100 kind 1 -> 0.496900 [[ 6.39482357 -2.22198617 -2.18901346 -2.1453085 -0.11805323] [-6.28885654 2.39260481 2.35279566 2.30920127 -0.2083632 ]] [[ 8.75113897 -3.18484341 -3.20304568 -3.10944871 0.2151458 ] [-8.72994808 3.26770859 3.2833056 3.18813441 -0.35640864]] ... 
[[ 23.98904034 -9.39850863 -9.62271368 -9.41844857 2.13802609] [-23.99935695 9.40265978 9.62690629 9.42265612 -2.13912499]] [[ 24.27733048 -9.51450089 -9.73983514 -9.53600906 2.16868557] [-24.28741351 9.51855807 9.74393065 9.54012085 -2.16975711]] errorRatioOnTrainData = 0.003000 dim = 4, errorRatio = 0.004000 dim = 5, kind = 2, dataSize = 10000 kind 0 -> 0.500000 kind 1 -> 0.500000 [[ 6.89705474 -1.42540518 -1.41940664 -1.48996056 -1.32395489 -0.50462332] [-6.75241758 1.58499742 1.59950655 1.65460537 1.49287702 0.09657823]] [[ 9.40073388 -2.10821137 -2.14136888 -2.16572326 -2.01953875 -0.37016685] [-9.34976596 2.19289023 2.23746207 2.24904918 2.10980714 0.16772995]] ... [[ 35.23149069 -9.72293551 -10.03130329 -9.42816884 -9.61618479 1.79576885] [-35.23648585 9.72438093 10.03278997 9.42956334 9.61760598 -1.7961253 ]] [[ 35.39225483 -9.76945223 -10.0791467 -9.47305001 -9.66192163 1.80723642] [-35.3972061 9.77088478 10.08062007 9.4744323 9.66333016 -1.80758947]] errorRatioOnTrainData = 0.003667 dim = 5, errorRatio = 0.005714 dim = 1, kind = 3, dataSize = 10000 kind 0 -> 0.321300 kind 1 -> 0.344100 kind 2 -> 0.334600 [[-6.66135149 1.88485463] [-0.02324545 -0.47039097] [ 6.1365322 -4.22324277]] [[-9.42188847 2.89492636] [-0.04642244 -0.45771294] [ 8.62181647 -5.85759954]] ... [[-5.10898588e+01 1.69452874e+01] [-4.93172361e-02 -4.56130326e-01] [ 4.72560247e+01 -3.16276995e+01]] [[-5.11769199e+01 1.69744423e+01] [-4.93172361e-02 -4.56130326e-01] [ 4.73360203e+01 -3.16809702e+01]] errorRatioOnTrainData = 0.014333 dim = 1, errorRatio = 0.014714 dim = 2, kind = 3, dataSize = 10000 kind 0 -> 0.227200 kind 1 -> 0.530300 kind 2 -> 0.242500 [[-5.00676085 -2.45044419 1.97399071] [ 0.22096798 -0.01350022 0.13430097] [ 4.20653754 2.23601111 -4.88443453]] [[-7.28718161 -3.77134478 3.32839949] [ 0.18256302 -0.06175728 0.17951199] [ 6.15014501 3.44052674 -6.86263964]] ... 
[[-42.98370113 -23.84773265 22.1558293 ] [ 0.17536418 -0.06894655 0.18704851] [ 38.02649615 21.76101314 -39.93451349]] [[-43.05806049 -23.8893055 22.19431298] [ 0.17536418 -0.06894655 0.18704851] [ 38.09248731 21.79835741 -40.00352016]] errorRatioOnTrainData = 0.007667 dim = 2, errorRatio = 0.015143 dim = 2, kind = 4, dataSize = 10000 kind 0 -> 0.126800 kind 1 -> 0.364700 kind 2 -> 0.372200 kind 3 -> 0.136300 [[-3.98654929 -2.59558945 0.63871501] [-2.98919471 -0.69136057 1.08710997] [ 3.07843511 0.64949316 -2.34760386] [ 2.97827188 1.97871338 -4.597351 ]] [[-6.01515149 -3.94558636 1.76243622] [-3.5549713 -0.991616 1.49906178] [ 3.66384428 0.91725841 -2.82881226] [ 4.65164848 3.12899868 -6.4712698 ]] ... [[-39.81632455 -22.97514101 15.60182524] [ -3.75722716 -1.11931026 1.6551282 ] [ 3.93424393 1.06984973 -3.06739184] [ 32.71042376 19.01196407 -38.71261752]] [[-39.88439704 -23.01307333 15.62864229] [ -3.75722716 -1.11931026 1.6551282 ] [ 3.93424393 1.06984973 -3.06739184] [ 32.7681305 19.04391614 -38.77955666]] errorRatioOnTrainData = 0.102000 dim = 2, errorRatio = 0.104429 dim = 3, kind = 3, dataSize = 10000 kind 0 -> 0.170600 kind 1 -> 0.651200 kind 2 -> 0.178200 [[-2.80037838 -1.85733668 -3.24123098 1.57259022] [-0.16142963 0.24714771 -0.19162875 0.6026281 ] [ 2.58163287 1.19894974 2.85973347 -4.87914069]] [[-4.24260115 -2.85160253 -4.82350704 3.1797816 ] [-0.27409502 0.16345042 -0.29640289 0.77147737] [ 3.84381624 2.04689038 4.19587946 -7.05644654]] ... 
[[-28.0745761 -19.73079916 -31.09831116 26.15657118] [ -0.30475613 0.13589837 -0.32360268 0.81945196] [ 23.68723658 16.08356041 26.22820206 -43.78315189]] [[-28.1238723 -19.76537271 -31.15279228 26.20290789] [ -0.30475613 0.13589837 -0.32360268 0.81945196] [ 23.72821623 16.11223028 26.27366977 -43.85969012]] errorRatioOnTrainData = 0.023333 dim = 3, errorRatio = 0.024286 dim = 3, kind = 4, dataSize = 10000 kind 0 -> 0.067900 kind 1 -> 0.429700 kind 2 -> 0.426400 kind 3 -> 0.076000 [[-2.00252838 -1.72767335 -2.24466305 -0.5296669 ] [-2.10120768 -1.22514198 -2.89788163 2.66060746] [ 1.95193431 1.34684809 2.75742256 -3.30044301] [ 1.38776371 0.68175952 1.46934361 -4.12383637]] [[-3.07733838 -2.51564061 -3.4173077 0.63571256] [-2.85622162 -1.79855655 -3.71498848 3.76282969] [ 2.56857273 1.85537708 3.45570769 -4.31289162] [ 2.39085453 1.41231651 2.50636325 -5.890961 ]] ... [[-23.5413656 -16.6567255 -25.3634883 16.160629 ] [ -3.59081061 -2.36125037 -4.49314391 4.81119705] [ 3.19014428 2.37566951 4.13596461 -5.3188518 ] [ 18.54073398 13.1811811 19.81375414 -38.83427155]] [[-23.58185249 -16.68489085 -25.40812839 16.18962668] [ -3.59081061 -2.36125037 -4.49314391 4.81119705] [ 3.19014428 2.37566951 4.13596461 -5.3188518 ] [ 18.57306601 13.20422866 19.84937411 -38.90186847]] errorRatioOnTrainData = 0.086000 dim = 3, errorRatio = 0.097429 dim = 4, kind = 4, dataSize = 10000 kind 0 -> 0.062600 kind 1 -> 0.428800 kind 2 -> 0.437600 kind 3 -> 0.071000 [[-0.71632006 -2.59391172 -1.87113508 -0.80861663 -0.56502796] [-0.11663555 -3.59928303 -1.15899786 -0.70171 2.44850899] [ 0.37202371 3.49748029 1.09665309 0.59203424 -3.0682648 ] [-0.12479086 1.7385499 1.16253629 0.31590619 -3.95319091]] [[-0.9326788 -4.01402554 -2.80183734 -1.15387173 0.54926063] [-0.4657425 -4.55847041 -1.71715683 -1.14886744 3.620152 ] [ 0.6652545 4.29734591 1.55022533 0.91784882 -4.0629789 ] [ 0.20215012 2.88925454 2.05115192 0.76774619 -5.70430072]] ... 
[[ -7.80293379 -32.00422274 -18.32788029 -9.67368173 16.64301744] [ -0.87738729 -5.56011488 -2.32996241 -1.63780016 4.88436562] [ 1.00962029 5.08689554 2.05662977 1.29336137 -5.13507463] [ 5.79210512 23.64914995 15.05904647 8.41324924 -39.64004545]] [[ -7.81695093 -32.0618173 -18.35982709 -9.69111765 16.67358517] [ -0.87738729 -5.56011488 -2.32996241 -1.63780016 4.88436562] [ 1.00962029 5.08689554 2.05662977 1.29336137 -5.13507463] [ 5.80255199 23.69329045 15.08405843 8.42725548 -39.7101195 ]] errorRatioOnTrainData = 0.118333 dim = 4, errorRatio = 0.111143 dim = 5, kind = 6, dataSize = 10000 kind 0 -> 0.005500 kind 1 -> 0.106600 kind 2 -> 0.374800 kind 3 -> 0.391000 kind 4 -> 0.118300 kind 5 -> 0.003800 [[-0.93151419 -1.15264489 -1.09286506 -1.07529811 -1.0044742 -2.86144712] [-0.42948115 -2.20080396 -1.46486443 -1.97822921 -0.81411715 1.04242008] [-0.05671104 -1.7655673 -1.16561097 -1.21070352 -0.62971664 1.73063998] [ 0.06718988 1.69013898 1.22077026 1.1620007 0.39884712 -2.95602739] [ 0.17581079 1.70913602 1.03362367 1.50272166 0.70368526 -4.39604687] [-0.7302654 -0.79562672 -0.71042569 -0.69781866 -0.78849863 -3.11233474]] [[-1.03026903 -1.35431366 -1.30340961 -1.27357011 -1.10123572 -2.8378497 ] [-0.81491271 -3.34293023 -2.30346758 -3.01864634 -1.40046136 2.63998911] [-0.31184047 -2.24459839 -1.59020725 -1.63260382 -0.97931141 2.64284953] [ 0.26039654 2.10258419 1.58891349 1.51831021 0.64976697 -3.7885964 ] [ 0.53260501 2.68629705 1.78230865 2.36776549 1.28603549 -6.49030701] [-0.68845101 -0.70282747 -0.60406175 -0.5841734 -0.73161019 -3.46754385]] ... 
[[ -4.64430182 -9.88889103 -10.32735806 -9.71878585 -5.75540649 6.03823936] [ -4.85415656 -14.00558042 -10.59404533 -12.56228907 -7.56849661 16.2053294 ] [ -0.52624566 -2.60362638 -1.91668611 -1.96394575 -1.259864 3.34867158] [ 0.45294325 2.43597606 1.89645053 1.82899407 0.87841299 -4.5130601 ] [ 4.66907859 14.56942264 10.74890854 12.60299926 7.68186565-33.74429057] [ 2.95492003 5.2915096 6.37632582 5.97786528 3.9272511 -21.7415685 ]] [[ -4.65279681 -9.90877286 -10.34768486 -9.7390159 -5.7680122 6.05417403] [ -4.85537082 -14.00906197 -10.59674944 -12.56542149 -7.57040832 16.20950731] [ -0.52624566 -2.60362638 -1.91668611 -1.96394575 -1.259864 3.34867158] [ 0.45294325 2.43597606 1.89645053 1.82899407 0.87841299 -4.5130601 ] [ 4.67283876 14.58138225 10.75770271 12.61334672 7.68803566-33.77157783] [ 2.96057027 5.30377548 6.39212882 5.9914551 3.93604998-21.78509134]] errorRatioOnTrainData = 0.097667 dim = 5, errorRatio = 0.106429
● 画图(一维)
● 画图(二维)
● 画图(三维)