描述:李航《统计学习方法》第二章感知机算法实现(Python)
原始形式:
# _*_ encoding:utf-8 _*_
"""Perceptron learning algorithm in primal form (Li Hang, Statistical Learning Methods, ch. 2)."""

import numpy as np


def createdata():
    """Return the toy training set and its class labels.

    Returns:
        samples: integer array of shape (3, 2), one sample per row.
        labels: integer array of shape (3,), the +1/-1 class of each sample.
    """
    samples = np.array([[3, 3], [4, 3], [1, 1]])
    labels = np.array([1, 1, -1])
    return samples, labels


class Perceptron:
    """Perceptron trained with the primal-form update rule.

    Attributes:
        x: training samples, one per row.
        y: +1/-1 labels aligned with the rows of ``x``.
        w: weight column vector of shape (numfeatures, 1), initialised to 0.
        b: bias, initialised to 0.
        a: learning rate (eta).
    """

    def __init__(self, x, y, a=1):
        self.x = x
        self.y = y
        self.w = np.zeros((x.shape[1], 1))
        self.b = 0
        self.a = a  # learning rate; previously hard-coded to 1, ignoring ``a``
        self.numsamples = self.x.shape[0]
        self.numfeatures = self.x.shape[1]

    def sign(self, w, b, x):
        """Return the score ``w . x + b`` of a single sample ``x``.

        Despite its name this returns the raw score, not its sign; ``train``
        tests ``score * y <= 0``. The score is kept as a float: truncating it
        with ``int()`` would turn any score in (-1, 1) into 0 and flag a
        correctly classified sample as an error whenever ``a`` or the data
        are fractional.
        """
        return float(np.dot(x, w).item() + b)

    def update(self, label_i, data_i):
        """Apply one primal update: w <- w + a*y_i*x_i and b <- b + a*y_i."""
        step = label_i * self.a
        self.w = self.w + step * np.reshape(data_i, self.w.shape)
        self.b = self.b + step

    def train(self, *, max_iter=None):
        """Sweep the training set until one full pass has no misclassification.

        Args:
            max_iter: optional limit on the number of passes. ``None`` (the
                default) loops until convergence, which never happens when
                the data are not linearly separable.

        Returns:
            (w, b): the learned weight column vector and bias.

        Raises:
            RuntimeError: if ``max_iter`` passes finish without converging.
        """
        passes = 0
        while max_iter is None or passes < max_iter:
            passes += 1
            count = 0
            for xi, yi in zip(self.x, self.y):
                if self.sign(self.w, self.b, xi) * yi <= 0:
                    print('误分类点为: ', xi, '此时的w和b为: ', self.w, self.b)
                    count += 1
                    self.update(yi, xi)
            if count == 0:
                print('最终训练得到的w和b为: ', self.w, self.b)
                return self.w, self.b
        raise RuntimeError(
            f'perceptron did not converge within {max_iter} passes; '
            'the data may not be linearly separable'
        )


class Picture:
    """Plot the samples and the learned separating line w0*x0 + w1*x1 + b = 0."""

    def __init__(self, data, w, b, labels=None):
        """Draw the samples and the separating line, then save to ``original.png``.

        Args:
            data: samples, one 2-D point per row.
            w: learned weights; any array shape holding exactly two values.
            b: learned bias.
            labels: optional +1/-1 labels. When given, positive samples are
                drawn as green dots and negative ones as blue crosses;
                otherwise all samples share one style.
        """
        # Imported here so the model can be used without matplotlib installed.
        import matplotlib.pyplot as plt

        self.b = b
        self.w = w
        data = np.asarray(data)
        plt.figure(1)
        plt.title('Perceptron Learning Algorithm', size=14)
        plt.xlabel('x0-axis', size=14)
        plt.ylabel('x1-axis', size=14)

        w0, w1 = np.ravel(self.w)
        if w1 != 0:
            # Span the data range with a unit margin on each side.
            xData = np.linspace(data[:, 0].min() - 1, data[:, 0].max() + 1, 100)
            yData = self.expression(xData)
            plt.plot(xData, yData, color='r', label='separating hyperplane')
        elif w0 != 0:
            # Vertical line x0 = -b / w0, which expression() cannot represent.
            plt.axvline(-self.b / w0, color='r', label='separating hyperplane')
        # w == 0: no hyperplane was learned, so only the samples are drawn.

        if labels is None:
            plt.scatter(data[:, 0], data[:, 1], s=50)
        else:
            positive = np.asarray(labels) > 0
            plt.scatter(data[positive, 0], data[positive, 1], c='g', s=50)
            plt.scatter(data[~positive, 0], data[~positive, 1], c='b', s=50, marker='x')

        plt.savefig('original.png', dpi=75)

    def expression(self, x):
        """Return x1 on the separating line for each x0 in ``x``: (-b - w0*x0) / w1."""
        w0, w1 = np.ravel(self.w)
        return (-self.b - w0 * x) / w1

    def show(self):
        """Display the current figure."""
        import matplotlib.pyplot as plt

        plt.show()


if __name__ == '__main__':
    samples, labels = createdata()
    myperceptron = Perceptron(samples, labels)
    weights, bias = myperceptron.train()
    picture = Picture(samples, weights, bias, labels=labels)
    picture.show()
对偶形式:
# _*_ encoding:utf-8 _*_
"""Perceptron learning algorithm in dual form (Li Hang, Statistical Learning Methods, ch. 2)."""

import numpy as np


def createdata():
    """Return the toy training set and its class labels.

    Returns:
        samples: integer array of shape (3, 2), one sample per row.
        labels: integer array of shape (3,), the +1/-1 class of each sample.
    """
    samples = np.array([[3, 3], [4, 3], [1, 1]])
    labels = np.array([1, 1, -1])
    return samples, labels


class Perceptron:
    """Perceptron trained with the dual-form update rule.

    In the dual form the weights are never updated directly. Instead
    ``self.w`` holds one coefficient alpha_i per sample (shape
    (1, numsamples)). alpha_i grows by the learning rate each time sample i
    is misclassified. The primal weights are recovered by ``cal_w`` as
    w = sum_i alpha_i * y_i * x_i.

    Attributes:
        x: training samples, one per row.
        y: +1/-1 labels aligned with the rows of ``x``.
        w: dual coefficients alpha, shape (1, numsamples), initialised to 0.
        b: bias, initialised to 0.
        a: learning rate (eta).
        gmatrix: Gram matrix of the samples, precomputed once.
    """

    def __init__(self, x, y, a=1):
        """Store the data set, labels and learning rate; zero the parameters."""
        self.x = x
        self.y = y
        self.w = np.zeros((1, x.shape[0]))  # alpha coefficients, one per sample
        self.b = 0
        self.a = a
        self.numsamples = self.x.shape[0]
        self.numfeatures = self.x.shape[1]
        self.gmatrix = self.gMatrix()

    def gMatrix(self):
        """Return the Gram matrix G[i, j] = x_i . x_j as a float array."""
        # One vectorised product instead of an O(n^2) Python double loop.
        return np.dot(self.x, self.x.T).astype(float)

    def sign(self, i):
        """Return the score of sample i: sum_j alpha_j * y_j * G[j, i] + b.

        Despite its name this returns the raw score, not its sign. The score is
        kept as a float: truncating it with ``int()`` would turn any score in
        (-1, 1) into 0 and flag a correctly classified sample as an error
        whenever ``a`` or the data are fractional.
        """
        return float(np.dot(self.w * self.y, self.gmatrix[:, i]).item() + self.b)

    def update(self, i):
        """Apply one dual update for sample i: alpha_i <- alpha_i + a, b <- b + a*y_i."""
        self.w[:, i] += self.a
        self.b = self.b + self.a * self.y[i]

    def cal_w(self):
        """Recover the primal weights w = sum_i alpha_i * y_i * x_i, shape (1, numfeatures)."""
        return np.dot(self.w * self.y, self.x)

    def train(self, *, max_iter=None):
        """Sweep the training set until one full pass has no misclassification.

        Args:
            max_iter: optional limit on the number of passes. ``None`` (the
                default) loops until convergence, which never happens when
                the data are not linearly separable.

        Returns:
            (w, b): the primal weights (shape (1, numfeatures)) and the bias.

        Raises:
            RuntimeError: if ``max_iter`` passes finish without converging.
        """
        passes = 0
        while max_iter is None or passes < max_iter:
            passes += 1
            count = 0
            for i in range(self.numsamples):
                if self.y[i] * self.sign(i) <= 0:
                    count += 1
                    print('误分类点为: ', self.x[i, :], '此时w和b分别为: ', self.cal_w(), ', ', self.b)
                    self.update(i)
            if count == 0:
                print('最终的w和b为: ', self.cal_w(), ', ', self.b)
                return self.cal_w(), self.b
        raise RuntimeError(
            f'perceptron did not converge within {max_iter} passes; '
            'the data may not be linearly separable'
        )


class Picture:
    """Plot the samples and the learned separating line w0*x0 + w1*x1 + b = 0."""

    def __init__(self, data, w, b, labels=None):
        """Draw the samples and the separating line, then save to ``test.png``.

        Args:
            data: samples, one 2-D point per row.
            w: learned weights; any array shape holding exactly two values.
            b: learned bias.
            labels: optional +1/-1 labels. When given, negative samples are
                drawn as crosses; otherwise all samples share one style.
        """
        # Imported here so the model can be used without matplotlib installed.
        import matplotlib.pyplot as plt

        self.w = w
        self.b = b
        data = np.asarray(data)
        plt.figure(1)
        plt.title('Perceptron Learning Algorithm of Duality', size=20)
        plt.xlabel('X0-axis', size=14)
        plt.ylabel('X1-axis', size=14)

        w0, w1 = np.ravel(self.w)
        if w1 != 0:
            # Span the data range with a unit margin on each side.
            xdata = np.linspace(data[:, 0].min() - 1, data[:, 0].max() + 1, 100)
            ydata = self.expression(xdata)
            plt.plot(xdata, ydata, c='r')
        elif w0 != 0:
            # Vertical line x0 = -b / w0, which expression() cannot represent.
            plt.axvline(-self.b / w0, c='r')
        # w == 0: no hyperplane was learned, so only the samples are drawn.

        if labels is None:
            plt.scatter(data[:, 0], data[:, 1], s=50)
        else:
            positive = np.asarray(labels) > 0
            plt.scatter(data[positive, 0], data[positive, 1], s=50)
            plt.scatter(data[~positive, 0], data[~positive, 1], s=50, marker='x')
        plt.savefig('test.png', dpi=95)

    def expression(self, xdata):
        """Return x1 on the separating line for each x0 in ``xdata``: (-b - w0*x0) / w1."""
        w0, w1 = np.ravel(self.w)
        return (-self.b - w0 * xdata) / w1

    def show(self):
        """Display the current figure."""
        import matplotlib.pyplot as plt

        plt.show()


if __name__ == '__main__':
    samples, labels = createdata()
    perceptron = Perceptron(x=samples, y=labels)
    weights, b = perceptron.train()
    picture = Picture(samples, weights, b, labels=labels)
    picture.show()
参考自:https://blog.csdn.net/u010626937/article/details/72896144