• [机器学习] 梯度下降算法


    梯度下降算法

    一、原理

    梯度下降算法是一个用来求解函数最小值的算法,关于算法的详细介绍就不细说了,直接来公式

    这里引用吴恩达机器学习的代价函数来说明

    • 代价函数

      [J( heta_0, heta_1) ]

    • 梯度下降算法

    [egin{align} heta_j &= heta_j - alphafrac{partial}{partial heta_j}J( heta_0, heta_1) \ j &= 0,1 end{align} ]

    二、实践(python)

    目标函数

    [f(x,y) = -e^{-(x^2+y^2)} ]

    import numpy as np
    import math
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D
    
    x = np.linspace(-2, 2, 50)
    y = np.linspace(-2, 2, 50)
    X,Y = np.meshgrid(x,y)
    Z = -np.exp(-(X**2 + Y**2))
    
    fig = plt.figure()
    ax = Axes3D(fig)
    ax.plot_surface(X, Y, Z, rstride = 1, cstride = 1, cmap = 'rainbow')
    plt.show()
    

    • 这里举例的函数比较简单,在(0, 0)取得最小值

    • 代码如下

      import math
      import numpy as np
      
      # x type: list, x[0], x[1]
      def grad_2d(x):
          temp0 = 2 * x[0] * math.exp(-(x[0]**2 + x[1]**2))
          temp1 = 2 * x[1] * math.exp(-(x[0]**2 + x[1]**2))
          return np.array([temp0, temp1])
      
      def gradient(grad, init_val = np.array([0, 0]), learning_rate = 0.01, precision = 0.0001, max_iters = 10000):
          print('init val:', init_val)
          cur_val = init_val
          for i in range(max_iters):
              grad_cur = grad(cur_val)
              if np.linalg.norm(grad_cur, ord=2) < precision:
                  break
              cur_val = cur_val - grad_cur * learning_rate
              print('第', i, '次迭代, 当前 x 为:', cur_val)
              
          print('min x = ', cur_val)
          return cur_val
      
      if __name__ == '__main__':
          gradient(grad_2d, init_val=np.array([1, -1]), learning_rate=0.2, precision=0.000001, max_iters=10000)
      
    • 运行结果

      init val: [ 1 -1]
      第 0 次迭代, 当前 x 为: [ 0.94586589 -0.94586589]
      第 1 次迭代, 当前 x 为: [ 0.88265443 -0.88265443]
      第 2 次迭代, 当前 x 为: [ 0.80832661 -0.80832661]
      第 3 次迭代, 当前 x 为: [ 0.72080448 -0.72080448]
      第 4 次迭代, 当前 x 为: [ 0.61880589 -0.61880589]
      第 5 次迭代, 当前 x 为: [ 0.50372222 -0.50372222]
      第 6 次迭代, 当前 x 为: [ 0.3824228 -0.3824228]
      第 7 次迭代, 当前 x 为: [ 0.26824673 -0.26824673]
      第 8 次迭代, 当前 x 为: [ 0.17532999 -0.17532999]
      第 9 次迭代, 当前 x 为: [ 0.10937992 -0.10937992]
      第 10 次迭代, 当前 x 为: [ 0.06666242 -0.06666242]
      第 11 次迭代, 当前 x 为: [ 0.04023339 -0.04023339]
      第 12 次迭代, 当前 x 为: [ 0.02419205 -0.02419205]
      第 13 次迭代, 当前 x 为: [ 0.01452655 -0.01452655]
      第 14 次迭代, 当前 x 为: [ 0.00871838 -0.00871838]
      第 15 次迭代, 当前 x 为: [ 0.00523156 -0.00523156]
      第 16 次迭代, 当前 x 为: [ 0.00313905 -0.00313905]
      第 17 次迭代, 当前 x 为: [ 0.00188346 -0.00188346]
      第 18 次迭代, 当前 x 为: [ 0.00113008 -0.00113008]
      第 19 次迭代, 当前 x 为: [ 0.00067805 -0.00067805]
      第 20 次迭代, 当前 x 为: [ 0.00040683 -0.00040683]
      第 21 次迭代, 当前 x 为: [ 0.0002441 -0.0002441]
      第 22 次迭代, 当前 x 为: [ 0.00014646 -0.00014646]
      第 23 次迭代, 当前 x 为: [ 8.78751305e-05 -8.78751305e-05]
      第 24 次迭代, 当前 x 为: [ 5.27250788e-05 -5.27250788e-05]
      第 25 次迭代, 当前 x 为: [ 3.16350474e-05 -3.16350474e-05]
      第 26 次迭代, 当前 x 为: [ 1.89810285e-05 -1.89810285e-05]
      第 27 次迭代, 当前 x 为: [ 1.13886171e-05 -1.13886171e-05]
      第 28 次迭代, 当前 x 为: [ 6.83317026e-06 -6.83317026e-06]
      第 29 次迭代, 当前 x 为: [ 4.09990215e-06 -4.09990215e-06]
      第 30 次迭代, 当前 x 为: [ 2.45994129e-06 -2.45994129e-06]
      第 31 次迭代, 当前 x 为: [ 1.47596478e-06 -1.47596478e-06]
      第 32 次迭代, 当前 x 为: [ 8.85578865e-07 -8.85578865e-07]
      第 33 次迭代, 当前 x 为: [ 5.31347319e-07 -5.31347319e-07]
      第 34 次迭代, 当前 x 为: [ 3.18808392e-07 -3.18808392e-07]
      min x =  [ 3.18808392e-07 -3.18808392e-07]
      
  • 相关阅读:
    JSP+Ajax站点开发小知识
    JavaScript向select下拉框中加入和删除元素
    debain install scim
    Xcode 5.1.1 与 Xcode 6.0.1 共存
    Oracle集合操作函数:Union、Union All、Intersect、Minus
    8皇后-----回溯法C++编程练习
    Copy-and-swap
    Android System Property 解析
    Android 仿PhotoShop调色板应用(二) 透明度绘制之AlphaPatternDrawable
    Android 仿PhotoShop调色板应用(一)概述
  • 原文地址:https://www.cnblogs.com/zou107/p/12584859.html
Copyright © 2020-2023  润新知