• 梯度下降算法之方程求解


    从上个月专攻机器学习,从本篇开始,我会陆续写机器学习的内容,都是我的学习笔记。

    gradienet

    问题

    梯度下降算法用于求数学方程的极大值极小值问题,这篇文章讲解如何利用梯度下降算法求解方程 (x^5+e^x+3x−3=0) 的根;

    方法

    首先来解决第一个问题,从方程的形式我们就能初步判断,它很可能没有闭式解。我能想到的最直观的解决方法就是画出函数图,函数图与 x 轴的交点就是方程的解,那先画个图看看

    从函数图像大体可以判断,方程的根在 0 附近,但是很明显 0 不是方程的根,看图只能猜出个大概,那怎么做才能得到更精确的解呢?

    有一个可行的方法在 x = 0 附近找一堆很接近的数字,比如 [−0.5:0.05:1][−0.5:0.05:1],一个个代入方程的左边,看看它的值离 0 有多近:距离 0 越近,说明我们选取的值离方程的根也越近。数学上定义两个数距离就是绝对值,但是因为绝对值不便于计算,所以将其替换成等价的差的平方,即 F(x)=(f(x)−0)2F(x)=(f(x)−0)2,以此度量结果距离 0 的程度,称之为损失函数

    我们代入计算得到如下的结果

    x: -0.500, f(x): -3.9247, F(x): 15.4034
    x: -0.450, f(x): -3.7308, F(x): 13.9191
    x: -0.400, f(x): -3.5399, F(x): 12.5310
    x: -0.350, f(x): -3.3506, F(x): 11.2263
    x: -0.300, f(x): -3.1616, F(x): 9.9958
    x: -0.250, f(x): -2.9722, F(x): 8.8338
    x: -0.200, f(x): -2.7816, F(x): 7.7372
    x: -0.150, f(x): -2.5894, F(x): 6.7048
    x: -0.100, f(x): -2.3952, F(x): 5.7369
    x: -0.050, f(x): -2.1988, F(x): 4.8346
    x: -0.000, f(x): -2.0000, F(x): 4.0000
    x: 0.050, f(x): -1.7987, F(x): 3.2354
    x: 0.100, f(x): -1.5948, F(x): 2.5434
    x: 0.150, f(x): -1.3881, F(x): 1.9268
    x: 0.200, f(x): -1.1783, F(x): 1.3883
    x: 0.250, f(x): -0.9650, F(x): 0.9312
    x: 0.300, f(x): -0.7477, F(x): 0.5591
    x: 0.350, f(x): -0.5257, F(x): 0.2763
    x: 0.400, f(x): -0.2979, F(x): 0.0888
    x: 0.450, f(x): -0.0632, F(x): 0.0040
    x: 0.500, f(x): 0.1800, F(x): 0.0324
    x: 0.550, f(x): 0.4336, F(x): 0.1880
    x: 0.600, f(x): 0.6999, F(x): 0.4898
    x: 0.650, f(x): 0.9816, F(x): 0.9635
    x: 0.700, f(x): 1.2818, F(x): 1.6431
    x: 0.750, f(x): 1.6043, F(x): 2.5738
    x: 0.800, f(x): 1.9532, F(x): 3.8151
    x: 0.850, f(x): 2.3334, F(x): 5.4445
    x: 0.900, f(x): 2.7501, F(x): 7.5630
    x: 0.950, f(x): 3.2095, F(x): 10.3008
    

    可以看出,x = 0.5,结果已经很接近 0 了,方程的根应该在 0.45~0.50 之间,而且 0.45 时,F(x) 的值更小,说明离 0.45 距离更近。接下来,一个可行的方法是将这段再细分成更小的区间,再如上面这样尝试,直到结果满意为止。但是这样做太过机械,每次需要手动调整区间和步长,有没有一种方法可以自动调整呢?

    再回到我们的问题,求解方程的根,就是找到一个点使得损失函数最小,我们画出来这个函数的曲线看看

    我们假定方程的根是 x0x0,除了 x0x0,其他点的函数值都比该点处的高,而且从两边向内,越是靠近 x0x0,函数的值越接近 0。而且可以发现,从两边向 x0x0 移动,方向刚好就是该点处切线的斜率 F′(x)F′(x) 的相反数。

    斜率

    于是得到启发,挑选一个初始点,沿着该点的斜率相反的方向迭代,必然越来越靠近方程的根,所以有下面的算法:

    1. 对于方程 f(x)=0f(x)=0,舍设定损失函数 F(x)=(f(x)−0)2F(x)=(f(x)−0)2;
    2. 设定一个初值 x0x0,代入损失函数求得结果,如果大于 0,那么找到一个新的值 x1=x0−αF′(x0)x1=x0−αF′(x0),考察损失函数是否为 0;
    3. 反复迭代第 2 步,直到达到满意的精度为止。

    上面的算法中,有三个参数需要注意:

    • αα,称为学习率,代表了曲线逼近的速度,这个参数可以自己设定;
    • 迭代次数,第 2 步运行的次数,迭代次数越多,我们离理想的结果越接近;
    • 精度,定义为 |F(x)||F(x)|,表示迭代的效果

    这三个参数中,迭代次数和精度可以作为迭代的终止条件,比如迭代次数达到 10000 次或者精度达到某个很小的数值 σσ 就终止运行。

    下面我们使用 python 程序来演示该算法的效果:

    # _*_ coding: utf-8 _*_
    import numpy as np
    
    # 定义函数f(x)
        e = 2.71828182845904590
        return x**5 + e**x + 3*x - 3
    
    #定义损失函数
    def loss_fun(x):
        return (problem(x) - 0)**2
    
    #计算损失函数的斜率
    def slope_fx(x):
        delta  = 0.0000001;
        return (loss_fun(x+delta) - loss_fun(x-delta))/(2.0*delta)
    
    #代入f(x),计算数值
    def calcu_loss_fun(x,maxTimes,alpha):
            for i in range(maxTimes):
                x = x - slope_fx(x)*alpha;
                print 'times %d, x: %.13f, f(x): %.13f' % (i, x, problem(x))
    alpha = 0.01
    maxTimes = 100
    x = 0.0;
    
    calcu_loss_fun(x,maxTimes,alpha)
    

    其中的slope_fx计算方程的斜率,利用导数定义 f′(x)=f(x+Δx)−f(x)Δxf′(x)=f(x+Δx)−f(x)Δx。程序计算结果如下

    times 1, x: 0.2724712244717, f(x): -0.8678788871194
    times 2, x: 0.3478163723702, f(x): -0.5354882897920
    times 3, x: 0.3958941025006, f(x): -0.3168805921512
    times 4, x: 0.4251012218626, f(x): -0.1810687680246
    times 5, x: 0.4420964369242, f(x): -0.1008566369730
    times 6, x: 0.4516717013511, f(x): -0.0552506486831
    times 7, x: 0.4569525930429, f(x): -0.0299651603458
    times 8, x: 0.4598276021739, f(x): -0.0161585445219
    times 9, x: 0.4613811940466, f(x): -0.0086856358075
    times 10, x: 0.4622172450759, f(x): -0.0046606160693
    times 11, x: 0.4626661379649, f(x): -0.0024984737671
    times 12, x: 0.4629068614830, f(x): -0.0013387061269
    times 13, x: 0.4630358664583, f(x): -0.0007170954782
    times 14, x: 0.4631049762781, f(x): -0.0003840652503
    times 15, x: 0.4631419923255, f(x): -0.0002056832476
    times 16, x: 0.4631618165349, f(x): -0.0001101474736
    times 17, x: 0.4631724329502, f(x): -0.0000589848326
    times 18, x: 0.4631781181683, f(x): -0.0000315864570
    times 19, x: 0.4631811626230, f(x): -0.0000169144811
    times 20, x: 0.4631827929259, f(x): -0.0000090576372
    times 21, x: 0.4631836659475, f(x): -0.0000048503201
    times 22, x: 0.4631841334466, f(x): -0.0000025973198
    times 23, x: 0.4631843837899, f(x): -0.0000013908497
    times 24, x: 0.4631845178473, f(x): -0.0000007447918
    times 25, x: 0.4631845896343, f(x): -0.0000003988315
    times 26, x: 0.4631846280757, f(x): -0.0000002135719
    times 27, x: 0.4631846486609, f(x): -0.0000001143664
    times 28, x: 0.4631846596842, f(x): -0.0000000612425
    times 29, x: 0.4631846655870, f(x): -0.0000000327950
    times 30, x: 0.4631846687480, f(x): -0.0000000175615
    times 31, x: 0.4631846704407, f(x): -0.0000000094041
    times 32, x: 0.4631846713471, f(x): -0.0000000050358
    times 33, x: 0.4631846718325, f(x): -0.0000000026967
    times 34, x: 0.4631846720924, f(x): -0.0000000014440
    times 35, x: 0.4631846722316, f(x): -0.0000000007733
    times 36, x: 0.4631846723061, f(x): -0.0000000004141
    times 37, x: 0.4631846723460, f(x): -0.0000000002217
    times 38, x: 0.4631846723674, f(x): -0.0000000001187
    times 39, x: 0.4631846723788, f(x): -0.0000000000636
    times 40, x: 0.4631846723850, f(x): -0.0000000000340
    times 41, x: 0.4631846723882, f(x): -0.0000000000182
    times 42, x: 0.4631846723900, f(x): -0.0000000000098
    times 43, x: 0.4631846723909, f(x): -0.0000000000052
    times 44, x: 0.4631846723914, f(x): -0.0000000000028
    times 45, x: 0.4631846723917, f(x): -0.0000000000015
    times 46, x: 0.4631846723919, f(x): -0.0000000000008
    times 47, x: 0.4631846723919, f(x): -0.0000000000004
    times 48, x: 0.4631846723920, f(x): -0.0000000000002
    times 49, x: 0.4631846723920, f(x): -0.0000000000001
    times 50, x: 0.4631846723920, f(x): -0.0000000000001
    times 51, x: 0.4631846723920, f(x): -0.0000000000000
    times 52, x: 0.4631846723920, f(x): -0.0000000000000
    times 53, x: 0.4631846723920, f(x): -0.0000000000000
    times 54, x: 0.4631846723920, f(x): -0.0000000000000
    

    迭代 52 次,就已经达到了理想的效果。

    参考资料

  • 相关阅读:
    localStorage存储数组以及取数组方法
    jq选择CheckBox进行排序
    js定时函数,定时改变字体的大小
    JQuery Datatable用法
    WebSocket实战
    代码段
    黎活明给程序员的忠告 收藏
    雅砻江后勤项目经验总结
    Java泛型方法
    回忆,梦的开始
  • 原文地址:https://www.cnblogs.com/bugxch/p/14190966.html
Copyright © 2020-2023  润新知