• [转]统计学习方法—chapter2—感知机算法实现


    描述：李航《统计学习方法》第二章感知机算法的 Python 实现（包括原始形式与对偶形式）

    原始形式：

     1 # _*_ encoding:utf-8 _*_
     2 
     3 import numpy as np
     4 import matplotlib.pyplot as plt
     5 
     6 
     7 def createdata():
     8     """创建数据集和相应类标记"""
     9     samples = np.array([[3, 3], [4, 3], [1, 1]])
    10     labels = np.array([1, 1, -1])
    11     return samples, labels
    12 
    13 
    14 
    15 class Perceptron:
    16     """感知机模型"""
    17 
    18     def __init__(self, x, y, a=1):
    19         self.x = x
    20         self.y = y
    21         self.w = np.zeros((x.shape[1], 1))
    22         self.b = 0
    23         self.a = 1  #学习率
    24         self.numsamples = self.x.shape[0]
    25         self.numfeatures = self.x.shape[1]
    26 
    27     def sign(self, w, b, x):
    28         """计算某样本点的f(x)"""
    29         y = np.dot(x, w) + b
    30         return int(y)
    31 
    32     def update(self, label_i, data_i):
    33         """更新w和b"""
    34         tmp = label_i * self.a * data_i
    35         tmp = tmp.reshape(self.w.shape)
    36         self.w = tmp + self.w
    37         self.b = self.b + label_i * self.a
    38 
    39     def train(self):
    40         """训练感知机模型"""
    41         isfind = False
    42         while not isfind:
    43             count = 0
    44             for i in range(self.numsamples):
    45                 tmp = self.sign(self.w, self.b, self.x[i, :])
    46                 if tmp * self.y[i] <= 0:
    47                     print('误分类点为: ', self.x[i, :], '此时的w和b为: ', self.w, self.b)
    48                     count += 1
    49                     self.update(self.y[i], self.x[i, :])
    50             if count == 0:
    51                 print('最终训练得到的w和b为: ', self.w, self.b)
    52                 isfind = True
    53         return self.w, self.b
    54 
    55 
    56 
    57 class Picture:
    58     """数据可视化"""
    59 
    60     def __init__(self, data, w, b):
    61         """初始化参数"""
    62         self.b = b
    63         self.w = w
    64         plt.figure(1)
    65         plt.title('Perceptron Learning Algorithm', size= 14)
    66         plt.xlabel('x0-axis', size=14)
    67         plt.ylabel('x1-axis', size=14)
    68 
    69         xData = np.linspace(0, 5, 100)
    70         yData = self.expression(xData)
    71         plt.plot(xData, yData, color='r', label='sample data')
    72 
    73         plt.scatter(data[0][0], data[0][1], c='r', s=50)
    74         plt.scatter(data[1][0], data[1][1], c='g', s=50)
    75         plt.scatter(data[2][0], data[2][1], s=50, c='b', marker='x')
    76  
    77         plt.savefig('original.png', dpi=75)
    78 
    79     def expression(self, x):
    80         """计算超平面上对应的纵坐标"""
    81         y = (-self.b - self.w[0] * x) / self.w[1]
    82         return y
    83 
    84     def show(self):
    85         """画图"""
    86         plt.show()
    87 
    88 
    89 if __name__ == '__main__':
    90     samples, labels = createdata()
    91     myperceptron = Perceptron(samples, labels)
    92     weights, bias = myperceptron.train()
    93     picture = Picture(samples, weights, bias)
    94     picture.show()

    对偶形式：

      1 # _*_ encoding:utf-8 _*_
      2 
      3 import numpy as np
      4 import matplotlib.pyplot as plt
      5 
      6 def createdata():
      7     """创建数据集和相应的类标记"""
      8     samples = np.array([[3, 3], [4, 3], [1, 1]])
      9     labels = np.array([1, 1, -1])
     10     return samples, labels
     11 
     12 
     13 class Perceptron:
     14     """感知机模型"""
     15 
     16     def __init__(self, x, y, a=1):
     17         """初始化数据集,标记,学习率,参数等"""
     18         self.x = x
     19         self.y = y
     20         self.w = np.zeros((1, x.shape[0]))
     21         self.b = 0
     22         self.a = a
     23         self.numsamples = self.x.shape[0]
     24         self.numfeatures = self.x.shape[1]
     25         self.gmatrix = self.gMatrix()
     26 
     27     def gMatrix(self):
     28         """计算Gram矩阵"""
     29         gmatrix = np.zeros((self.numsamples, self.numsamples))
     30         for i in range(self.numsamples):
     31             for j in range(self.numsamples):
     32                 gmatrix[i][j] = np.dot(self.x[i, :], self.x[j, :])
     33         return gmatrix
     34 
     35     def sign(self, i):
     36         """计算f(x)"""
     37         y = np.dot(self.w*self.y, self.gmatrix[:, i]) + self.b
     38         return int(y)
     39 
     40     def update(self, i):
     41         """更新w和b"""
     42         self.w[:, i] = self.w[:, i] + self.a
     43         self.b = self.b + self.a * self.y[i]
     44 
     45     def cal_w(self):
     46         """计算最终的w"""
     47         w = np.dot(self.w*self.y, self.x)
     48         return w
     49 
     50     def train(self):
     51         """感知机模型训练"""
     52         isfind = False
     53         while not isfind:
     54             count = 0
     55             for i in range(self.numsamples):
     56                 if self.y[i]*self.sign(i) <= 0:
     57                     count += 1
     58                     print('误分类点为: ', self.x[i, :], '此时w和b分别为: ', self.cal_w(), ', ', self.b)
     59                     self.update(i)
     60             if count == 0:
     61                 print('最终的w和b为: ', self. cal_w(), ', ', self.b)
     62                 isfind = True
     63         weights = self.cal_w()
     64         return weights, self.b
     65 
     66 
     67 class Picture:
     68     """数据可视化"""
     69 
     70     def __init__(self, data, w, b):
     71         """"初始化画图参数"""
     72         self.w = w
     73         self.b = b
     74         plt.figure(1)
     75         plt.title('Perceptron Learning Algorithm of Duality', size=20)
     76         plt.xlabel('X0-axis', size=14)
     77         plt.ylabel('X1-axis', size=14)
     78 
     79         xdata = np.linspace(1, 5, 100)
     80         ydata = self.expression(xdata)
     81         plt.plot(xdata, ydata, c='r')
     82 
     83         plt.scatter(data[0][0], data[0][1], s=50)
     84         plt.scatter(data[1][0], data[1][1], s=50)
     85         plt.scatter(data[2][0], data[2][1], s=50, marker='x')
     86         plt.savefig('test.png', dpi=95)
     87 
     88     def expression(self, xdata):
     89         """计算超平面上的纵坐标"""
     90         y = (-self.b - self.w[:, 0]*xdata) / self.w[:, 1]
     91         return y
     92 
     93     def show(self):
     94         """画图"""
     95         plt.show()
     96 
     97 
     98 if __name__ == '__main__':
     99     samples, labels = createdata()
    100     perceptron = Perceptron(x=samples, y=labels)
    101     weights, b = perceptron.train()
    102     picture = Picture(samples, weights, b)
    103     picture.show()

    参考自:https://blog.csdn.net/u010626937/article/details/72896144

  • 相关阅读:
    Tesseract-OCR
    chrome浏览器插件推荐
    安装配置sublime Text 3 快捷键方式
    远程桌面与远程控制
    一个C#的XML数据库访问类
    WPF小程序:贪吃蛇
    恐惧源于一知半解
    8条佛曰 66句禅语
    自动开机 双网卡网络唤醒
    C#,Java,C++中的finally关键字
  • 原文地址:https://www.cnblogs.com/OoycyoO/p/9538055.html
Copyright © 2020-2023  润新知