step 1
用高斯分布生成两类点
class Point3:
    """Negative-class sample: label -1, plotted red.

    Each coordinate is drawn independently from a Gaussian with mean 50
    and standard deviation 10.
    """

    def __init__(self):
        # Tuple elements are evaluated left to right, so x is drawn before y.
        self.x, self.y = random.gauss(50, 10), random.gauss(50, 10)
        self.label, self.color = -1, 'r'


class Point4:
    """Positive-class sample: label +1, plotted blue.

    Each coordinate is drawn independently from a Gaussian with mean 90
    and standard deviation 10.
    """

    def __init__(self):
        self.x, self.y = random.gauss(90, 10), random.gauss(90, 10)
        self.label, self.color = 1, 'b'
step 2
画一条初始直线：先定义两个点 (x1, 0) 和 (x2, 100)，其中 x1 为 [0, 50] 内的随机整数，x2 为 [50, 100] 内的随机整数（random.randint 包含两个端点）。有了这两个点，就可以画出一条直线。
class Line:
    """Random initial decision line w1*x + w2*y + b = 0, with w2 fixed to 1.

    The line passes through the plotting points (x1, 0) and (x2, 100).
    x1 is a random integer in [MIN, MAX//2] and x2 a different random
    integer in [MAX//2, MAX].
    """

    def __init__(self):
        self.x1 = random.randint(MIN, MAX // 2)  # [0, 50] when MIN=0, MAX=100
        self.x2 = random.randint(MAX // 2, MAX)  # [50, 100] when MIN=0, MAX=100
        # Both ranges contain MAX//2. If x1 == x2 the line would be vertical
        # and the slope below would divide by zero, so redraw x2 until they differ.
        while self.x2 == self.x1:
            self.x2 = random.randint(MAX // 2, MAX)
        self.y1 = 0
        self.y2 = 100

        # Endpoints in the form expected by plt.plot(xs, ys).
        self.x = [self.x1, self.x2]
        self.y = [self.y1, self.y2]

        # Normalise so that w2 == 1; w1 is then minus the slope.
        self.w1 = -(self.y2 - self.y1) / (self.x2 - self.x1)
        self.w2 = 1
        # Choose b so that (x1, y1) lies on the line: w1*x1 + w2*y1 + b == 0.
        self.b = -(self.w1 * self.x1 + self.w2 * self.y1)
step 3
判断误分类点
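正确分类 1：$w_1x + w_2y + b > 0$ 且 label = 1
正确分类 2：$w_1x + w_2y + b < 0$ 且 label = -1
两种情况可以合并为 $\text{label}\cdot(w_1x + w_2y + b) > 0$，不满足该条件的点即为误分类点。下面的 sign 函数返回的正是 $\text{label}\cdot(w_1x + w_2y + b)$。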
def sign(self, point):
    """Return the functional margin of *point* with respect to this line.

    The value is label * (w1*x + w2*y + b). With labels of +1/-1 it is
    positive when the point is on the side of the line matching its label,
    zero when the point is on the line, and negative otherwise.
    """
    activation = self.w1 * point.x + self.w2 * point.y + self.b
    return point.label * activation
step 4
有了更新后的 w1、w2 和 b，就可以画出新的直线。
画直线仍然需要两个点：y1 = 0、y2 = 100 保持不变，只需由直线方程 $w_1x + w_2y + b = 0$ 解出 $x = -(b + w_2y)/w_1$，再分别代入 y1、y2 求得 x1、x2 即可（要求 $w_1 \neq 0$）。
def update(self):
    """Recompute the plotting endpoints from the current w1, w2 and b.

    The y-coordinates y1 and y2 stay fixed. Each x is obtained by solving
    w1*x + w2*y + b = 0 for x at that y.

    Raises:
        ZeroDivisionError: if w1 == 0, i.e. the line is horizontal and
            cannot be drawn through the fixed y1/y2.
    """
    # Solve at y1 explicitly instead of assuming y1 == 0, so both endpoints
    # lie on the line whatever y1 is.
    self.x1 = (-self.b - self.w2 * self.y1) / self.w1
    self.x2 = (-self.b - self.w2 * self.y2) / self.w1
    self.x = [self.x1, self.x2]
    self.y = [self.y1, self.y2]
step 5
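w1、w2 和 b 的更新规则参考博文“支持向量机”：http://www.carefree0910.com/posts/d455305a/ 。对每个满足更新条件的点，令 $\eta$ 为代码中的 step：$w_1 \leftarrow (1-\eta)w_1 + \eta C\,\text{label}\,x$，$w_2 \leftarrow (1-\eta)w_2 + \eta C\,\text{label}\,y$，$b \leftarrow b + \eta C\,\text{label}$。注意，代码中的更新条件是 $\text{label}\cdot(w_1x + w_2y + b) < 1$。因此，除了误分类点，落在间隔内（函数间隔小于 1）的正确分类点也会触发更新。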
def preceptron_base_dis(all_points, max_iter=None):
    """Fit a line with regularised margin-based updates, plot it and return it.

    Starts from a random ``Line``, plotted as a green dashed line. It then
    sweeps over ``all_points`` and updates the line for every point whose
    functional margin is below 1. It stops after a sweep with no updates,
    and the final line is plotted.

    Args:
        all_points: points with ``x``, ``y`` and ``label`` attributes.
        max_iter: maximum number of sweeps over ``all_points``. ``None``
            (the default) keeps sweeping until no point triggers an update.
            That may never happen if no line gives every point a margin of
            at least 1.

    Returns:
        The fitted ``Line``.
    """
    line = Line()
    plt.plot(line.x, line.y, 'g--', linewidth=1)  # initial line
    sweeps = 0
    while max_iter is None or sweeps < max_iter:
        sweeps += 1
        updated = False
        for point in all_points:
            # Margin < 1 covers misclassified points AND correctly classified
            # points that fall inside the margin.
            if line.sign(point) < 1:
                # (1 - step) shrinks the weights on every update; b is not shrunk.
                line.w1 = (1 - step) * line.w1 + step * C * point.label * point.x
                line.w2 = (1 - step) * line.w2 + step * C * point.label * point.y
                line.b = line.b + step * C * point.label
                updated = True
        if not updated:
            break
        line.update()  # refresh plotting endpoints after this sweep
    plt.plot(line.x, line.y, '.-', linewidth=1)
    plt.show()
    return line
全部代码汇总
"""Perceptron / regularised linear classifier demo on 2-D points.

Draws two clusters of labelled points and an initial random line. It then
iteratively updates the line's parameters (w1, w2, b) and plots the result.
"""
import matplotlib.pyplot as plt
import numpy  # NOTE: currently unused in this script
import random
import sys  # NOTE: currently unused in this script

MAX = 100
MIN = 0
POINT_NUM = 20  # points generated per class by initialPoint
step = 0.01  # learning rate used by both training loops
C = 0.1  # weight of the per-point term in preceptron_base_dis's update


class Point:
    """Uniform random point in [MIN, MAX]^2, labelled by the side of y = x."""

    def __init__(self):
        self.x = random.uniform(MIN, MAX)
        self.y = random.uniform(MIN, MAX)

        if self.x > self.y:
            self.label = 1
            self.color = 'b'
        else:
            self.label = -1
            self.color = 'r'


class Point2:
    """Random integer point labelled by the side of y = x.

    Points with x in the right half get a small y, and the others get a
    large y.
    """

    def __init__(self):
        self.x = random.randint(MIN, MAX)
        if self.x > MAX // 2:
            self.y = random.randint(0, MAX // 4)
        else:
            self.y = random.randint(MAX * 2 // 4, MAX)

        if self.x > self.y:
            self.label = 1
            self.color = 'b'
        else:
            self.label = -1
            self.color = 'r'


class Point3:
    """Negative-class sample (label -1, red); each coordinate ~ gauss(50, 10)."""

    def __init__(self):
        self.x = random.gauss(50, 10)
        self.y = random.gauss(50, 10)

        self.label = -1
        self.color = 'r'


class Point4:
    """Positive-class sample (label +1, blue); each coordinate ~ gauss(90, 10)."""

    def __init__(self):
        self.x = random.gauss(90, 10)
        self.y = random.gauss(90, 10)

        self.label = 1
        self.color = 'b'


class Line:
    """Decision line w1*x + w2*y + b = 0, with w2 initialised to 1.

    For plotting, the line is also kept as two endpoints (x1, y1) and
    (x2, y2), with y1 = 0 and y2 = 100 fixed.
    """

    def __init__(self):
        self.x1 = random.randint(MIN, MAX // 2)  # [0, 50] when MIN=0, MAX=100
        self.x2 = random.randint(MAX // 2, MAX)  # [50, 100] when MIN=0, MAX=100
        # Both ranges contain MAX//2. If x1 == x2 the line would be vertical
        # and the slope below would divide by zero, so redraw x2 until they differ.
        while self.x2 == self.x1:
            self.x2 = random.randint(MAX // 2, MAX)
        self.y1 = 0
        self.y2 = 100

        self.x = [self.x1, self.x2]
        self.y = [self.y1, self.y2]

        # Normalise so that w2 == 1; w1 is then minus the slope.
        self.w1 = -(self.y2 - self.y1) / (self.x2 - self.x1)
        self.w2 = 1
        # Choose b so that (x1, y1) lies on the line: w1*x1 + w2*y1 + b == 0.
        self.b = -(self.w1 * self.x1 + self.w2 * self.y1)

    def sign(self, point):
        """Return label * (w1*x + w2*y + b), the functional margin of *point*."""
        return point.label * (self.w1 * point.x + self.w2 * point.y + self.b)

    def update(self):
        """Recompute x1/x2 (and the plotting lists) from the current w1, w2, b.

        Raises ZeroDivisionError if w1 == 0 (a horizontal line cannot be
        drawn through the fixed y1/y2).
        """
        # Solve w1*x + w2*y + b = 0 for x at y = y1 and at y = y2.
        self.x1 = (-self.b - self.w2 * self.y1) / self.w1
        self.x2 = (-self.b - self.w2 * self.y2) / self.w1
        self.x = [self.x1, self.x2]
        self.y = [self.y1, self.y2]


def initialPoint():
    """Open a new figure, scatter-plot 2 * POINT_NUM points and return them.

    The first POINT_NUM points are Point3 (label -1) and the next
    POINT_NUM are Point4 (label +1).
    """
    plt.figure()
    all_point = []
    for point_cls in (Point3, Point4):
        for _ in range(POINT_NUM):
            p = point_cls()
            plt.plot(p.x, p.y, p.color + 'o', label="point")
            all_point.append(p)
    return all_point


def preceptron_base_dis(all_points, max_iter=None):
    """Fit a line with regularised margin-based updates, plot and return it.

    Every point with functional margin below 1 triggers an update. Training
    stops after a sweep with no updates, or after ``max_iter`` sweeps when
    it is not None. With ``max_iter=None`` the loop may never end if no line
    gives every point a margin of at least 1.
    """
    line = Line()
    plt.plot(line.x, line.y, 'g--', linewidth=1)  # initial line
    sweeps = 0
    while max_iter is None or sweeps < max_iter:
        sweeps += 1
        updated = False
        for point in all_points:
            # Margin < 1 covers misclassified points AND correctly classified
            # points inside the margin.
            if line.sign(point) < 1:
                # (1 - step) shrinks the weights on every update; b is not shrunk.
                line.w1 = (1 - step) * line.w1 + step * C * point.label * point.x
                line.w2 = (1 - step) * line.w2 + step * C * point.label * point.y
                line.b = line.b + step * C * point.label
                updated = True
        if not updated:
            break
        line.update()
    plt.plot(line.x, line.y, '.-', linewidth=1)
    plt.show()
    return line


def preceptron(all_points, max_iter=None):
    """Fit a line with the classic perceptron rule, plot and return it.

    Every point with margin <= 0 (misclassified or on the line) triggers an
    update. Training stops after a sweep with no updates, or after
    ``max_iter`` sweeps when it is not None. With ``max_iter=None`` the loop
    never ends on data that no line separates.
    """
    line = Line()
    plt.plot(line.x, line.y, 'g--', linewidth=1)  # initial line
    sweeps = 0
    while max_iter is None or sweeps < max_iter:
        sweeps += 1
        updated = False
        for point in all_points:
            if line.sign(point) <= 0:
                line.w1 += step * point.label * point.x
                line.w2 += step * point.label * point.y
                line.b += step * point.label
                updated = True
        if not updated:
            break
        line.update()
    plt.plot(line.x, line.y, 'o-', linewidth=1)
    plt.show()
    return line


if __name__ == "__main__":
    all_points = initialPoint()
    preceptron_base_dis(all_points)