• 感知机(perceptron) 学习笔记


    前言:偶尔回想自己学过的算法,反复看反复忘,故又重复看一遍并记下笔记,供后续学习参考。


    感知机是一个二分类算法,是深度学习的简化版,只有一层网络,建模思想跟支持向量机类似,是两算法的基础。

    分类原理 :y =(编辑公式还有更好的办法吗?), 满足wx+b>=0的输入被分类为标签1,否则被分类为标签-1。

    建模:分类的是超平面wx+b=0,以输入点到超平面的距离为判断标准。距离公式为d = |wx+b|/ ||w|| = y(wx+b)/||w||。||.||是二范数。

    损失函数:损失的自然选择是错误分类的总点数,但是这样的损失不是参数w,b的连续可到函数,不易优化,所以将错误分类的总点数到超平面的距离的总距离定义为损失函数,因此损失函数deltaL =  -1*y(wx+b)/||w||,此处根据SVM中几何间隔,知需要对w进行约束,避免b与w同比例增加,参数变了但是超平面本身并没有变动,所以约定||w||=1,损失函数为deltaL =  -1*y(wx+b),更细致的推理可参见其他资料,(我用||w||计算损失进行训练没见特别冥想的错误?)。

    优化算法:优化算法是为了求得最优解,这里使用随机梯度下降算法SGD,已知损失函数,根据其对变量w,b求导得到,dw = -yx, db = -y,设学习率为u = 0.5。w= w-dw = w+uyx, b = b - db = b+uy

    代码如下:

    import math
    import random
    '''
    感知机是二分类模型
    判断条件 y = 1 when w1x1+w2x2+b>0
            y = -1 when w1x1+w2x2+b<0
    
    '''
    
    #z准备数据集
    X = [[3,2], [12,10], [33,62], [8,16], [23,45], [7,13], [78,65], [35,54], [77,55], [89,23]]
    Y = [-1, -1, 1, -1, -1, -1,1, -1, 1, 1]
    
    #训练网络/推理逻辑
    # y= x1*w1 + x2*w2 +b
    
    w1 = 0.1
    w2 = 0.1
    b = 0
    u = 0.5
    # for [x1,x2] in X:
    #     y = x1*w1 + x2*w2 +b
    
    
    #计算损失函数
    #点到平面wx+b=0 的距离 d = |wx+b| / ||w||,对于误分类的则需要被统计作为损失,目标是使损失变为0
    #d = -yi*(wxi+b)/math.sqrt(sum(pow(w,2)))
    
    #优化算法 梯度函数
    #dw = yixi ; db = yi, u是步长(学习率)
    #w = w+ udw
    #b = b+ udb
    
    #每个epoch遇到误分类点后更新一次变量
    deltal = 0
    for i in range(10000):
        print(i)
        deltal = 0
        for i in range(20):
    
            index = random.randint(0,len(X)-1)
            # print((X[index][0]*w1 + X[index][1]*w2 +b)* Y[index])
            # print((X[index][0]*w1 + X[index][1]*w2 +b)* Y[index]*(-1))
            if((X[index][0]*w1 + X[index][1]*w2 +b)* Y[index]<=0):
                print("error")
                w1 = w1+ u*Y[index]*X[index][0]
                w2 = w2 + u*Y[index]*X[index][1]
                b = b + u*Y[index]
                # print(math.sqrt(pow(w1,2)+ pow(w2,2)))
                deltal =deltal + (X[index][0]*w1 + X[index][1]*w2 +b)* Y[index]*(-1)/math.sqrt(pow(w1,2)+ pow(w2,2))
                # print(deltal)
        print('损失函数是:{}, w1 is:{}, w2 is:{}, b is:{}'.format(deltal, w1, w2, b))
        if(deltal==0):
            break
    

      

         得到的结果如下:

    0
    error
    error
    error
    error
    error
    error
    error
    损失函数是:-58.3676307788504, w1 is:9.600000000000001, w2 is:-11.899999999999999, b is:-2.5
    1
    error
    error
    error
    error
    error
    error
    error
    损失函数是:-79.6906079022878, w1 is:16.1, w2 is:-27.4, b is:-5.0
    2
    error
    error
    error
    error
    error
    error
    error
    error
    error
    error
    error
    损失函数是:-181.72511941015193, w1 is:19.1, w2 is:-35.4, b is:-7.5
    3
    error
    error
    error
    error
    error
    error
    error
    error
    error
    损失函数是:-134.1050074696168, w1 is:52.1, w2 is:-21.4, b is:-9.0
    4
    error
    error
    error
    error
    error
    error
    error
    error
    error
    error
    error
    损失函数是:-81.79882696995062, w1 is:21.1, w2 is:-49.9, b is:-12.5
    5
    error
    error
    error
    error
    error
    error
    error
    error
    error
    error
    error
    损失函数是:-119.81313066808934, w1 is:23.6, w2 is:-19.9, b is:-14.0
    6
    error
    error
    error
    error
    损失函数是:-53.31304070579386, w1 is:21.1, w2 is:-17.4, b is:-15.0
    7
    error
    error
    error
    error
    error
    损失函数是:-41.93949987798639, w1 is:31.6, w2 is:9.600000000000001, b is:-16.5
    8
    error
    error
    error
    损失函数是:-5.054308599796722, w1 is:17.1, w2 is:-14.899999999999999, b is:-18.0
    9
    error
    error
    损失函数是:-5.061614611137055, w1 is:9.600000000000001, w2 is:-20.9, b is:-19.0
    10
    error
    error
    error
    error
    error
    error
    error
    error
    error
    error
    error
    损失函数是:-176.0012076253418, w1 is:30.1, w2 is:6.100000000000001, b is:-21.5
    11
    error
    error
    error
    error
    error
    error
    error
    error
    error
    error
    损失函数是:-125.48127046860397, w1 is:37.1, w2 is:-13.899999999999999, b is:-24.5
    12
    error
    error
    error
    error
    error
    error
    error
    error
    error
    error
    error
    损失函数是:-125.1563990295305, w1 is:7.100000000000001, w2 is:-41.9, b is:-28.0
    13
    error
    error
    error
    error
    error
    error
    error
    error
    损失函数是:-99.9588692981525, w1 is:31.6, w2 is:-34.4, b is:-30.0
    14
    error
    error
    error
    error
    error
    error
    error
    error
    损失函数是:-123.83992037821682, w1 is:39.1, w2 is:-43.4, b is:-32.0
    15
    error
    error
    error
    error
    error
    error
    error
    error
    error
    error
    损失函数是:-107.52054118479901, w1 is:38.099999999999994, w2 is:-61.4, b is:-35.0
    16
    error
    error
    error
    error
    error
    error
    损失函数是:-12.707209628891478, w1 is:53.599999999999994, w2 is:-18.4, b is:-36.0
    17
    error
    error
    error
    error
    error
    error
    error
    损失函数是:-23.948721417884677, w1 is:40.099999999999994, w2 is:-28.4, b is:-38.5
    18
    error
    error
    error
    error
    error
    error
    error
    error
    损失函数是:-139.31131437239534, w1 is:31.099999999999994, w2 is:-56.4, b is:-40.5
    19
    error
    error
    error
    error
    error
    error
    error
    损失函数是:-95.98994407835531, w1 is:48.099999999999994, w2 is:-62.9, b is:-42.0
    20
    error
    error
    error
    error
    error
    error
    error
    error
    error
    损失函数是:-92.55322975282904, w1 is:78.6, w2 is:-39.9, b is:-43.5
    21
    error
    error
    损失函数是:-17.88492872798486, w1 is:55.099999999999994, w2 is:-71.9, b is:-44.5
    22
    error
    error
    损失函数是:-18.30431876055715, w1 is:54.099999999999994, w2 is:-67.9, b is:-44.5
    23
    error
    error
    损失函数是:-19.673217617324354, w1 is:53.099999999999994, w2 is:-63.900000000000006, b is:-44.5
    24
    error
    error
    损失函数是:-21.10211583473629, w1 is:52.099999999999994, w2 is:-59.900000000000006, b is:-44.5
    25
    error
    error
    error
    error
    error
    error
    error
    error
    损失函数是:-65.34583910969604, w1 is:66.6, w2 is:-47.400000000000006, b is:-46.5
    26
    error
    error
    error
    error
    error
    error
    error
    error
    error
    损失函数是:-74.24796184936555, w1 is:81.6, w2 is:-42.400000000000006, b is:-49.0
    27
    error
    error
    error
    损失函数是:-34.124857156842516, w1 is:63.099999999999994, w2 is:-65.4, b is:-49.5
    28
    error
    error
    error
    error
    error
    error
    error
    error
    error
    error
    损失函数是:-20.592229958037294, w1 is:52.099999999999994, w2 is:-56.400000000000006, b is:-52.5
    29
    error
    error
    error
    error
    error
    error
    error
    error
    损失函数是:-15.858697189511131, w1 is:58.099999999999994, w2 is:-24.900000000000006, b is:-54.5
    30
    error
    error
    error
    error
    error
    损失函数是:-59.07572886107646, w1 is:52.099999999999994, w2 is:-24.900000000000006, b is:-55.0
    31
    error
    error
    error
    error
    error
    损失函数是:-33.3306474537016, w1 is:35.099999999999994, w2 is:-38.900000000000006, b is:-56.5
    32
    损失函数是:0, w1 is:35.099999999999994, w2 is:-38.900000000000006, b is:-56.5
    
    Process finished with exit code 0

    重复运行多次,会有不同的结果。感知机算法由于采用不同的初值和选取不同的误分类点,解可以不同

    以上为个人理解,如有不对的地方,欢迎交流指正~

  • 相关阅读:
    循环获取数据
    implode
    获取二维数组中的值
    根据id获取某一类的最大最小值
    array_column的作用
    用curl模拟夹带cookie的http请求
    phpunit——执行测试文件和测试文件中的某一个函数
    call_user_func
    9 [面向对象]-内置方法
    8 [面向对象]-反射
  • 原文地址:https://www.cnblogs.com/xiaoheizi-12345/p/13191641.html
Copyright © 2020-2023  润新知