• 两个变量(可支持多自变量)的简单梯度下降


    import numpy as np
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d.axes3d import Axes3D
    
    #   公式 f(x,y) = 2x^2+6y^2+6xy+x+4y+8
    def targetFunc(x,y):
        """Evaluate f(x, y) = 2x^2 + 6y^2 + 6xy + x + 4y + 8.

        Only arithmetic operators are used, so x and y may be scalars or
        equally shaped NumPy arrays (evaluated element-wise).
        """
        quadratic = 2 * (x ** 2) + 6 * y ** 2 + 6 * x * y
        # Keep the original left-to-right summation order for identical rounding.
        return quadratic + x + 4 * y + 8
    
    #   偏导
    #   f'x(x,y)=4x+6y+1
    #   f'y(x,y)=12y+6x+4
    def derivativeFunc(x,y):
        """Return the gradient of f(x, y) = 2x^2 + 6y^2 + 6xy + x + 4y + 8.

        Returns a tuple (df/dx, df/dy) where
            df/dx = 4x + 6y + 1
            df/dy = 12y + 6x + 4
        """
        return (4 * x + 6 * y + 1, 12 * y + 6 * x + 4)
    
    # Every point visited by linerFunc is appended here as a (*point, f(point))
    # tuple unless the caller passes its own list; __main__ scatters these.
    pointList = []
    
    def linerFunc(initPoint:tuple,targetFunc,derivativeFunc,step = 0.01,limitValue = 0.00000001,timeout=1000000,ax:"Axes3D" = None,points:list = None):
        """Minimise ``targetFunc`` by fixed-step gradient descent from ``initPoint``.

        Works for any number of variables: the current point is unpacked
        positionally into ``targetFunc`` and ``derivativeFunc``.

        Args:
            initPoint: starting coordinates, one entry per variable.
            targetFunc: callable f(*point) -> scalar.
            derivativeFunc: callable returning the gradient of f at *point as
                a sequence with one entry per variable.
            step: learning rate; each update is ``point - gradient * step``.
            limitValue: convergence threshold. Iteration stops once every
                component of the gradient change between two consecutive
                points is <= limitValue. After the first step, a coordinate
                whose update would be smaller than limitValue is left unchanged.
            timeout: maximum number of steps (the first step included).
            ax: unused; accepted for backward compatibility.
            points: list that receives a ``(*point, f(point))`` tuple for every
                visited point, including the returned one. Defaults to the
                module-level ``pointList``.

        Returns:
            ``(value, point)``: f evaluated at the final point, and that point
            as a NumPy float array.

        NOTE: the stopping rule looks at the *change* of the gradient, not its
        size, so for a function whose gradient is constant and non-zero (e.g.
        a linear one) the search stops after a single step.
        """
        if points is None:
            points = pointList

        point = np.asarray(initPoint, dtype=float)
        grad = np.asarray(derivativeFunc(*point))
        points.append((*point, targetFunc(*point)))

        # The first step is always taken in full.
        nextPoint = point - grad * step
        nextGrad = np.asarray(derivativeFunc(*nextPoint))
        diff = np.abs(grad - nextGrad)
        count = 1

        while (diff > limitValue).any() and count < timeout:
            # Reuse the gradient already computed for nextPoint instead of
            # re-evaluating derivativeFunc at the same coordinates.
            point, grad = nextPoint, nextGrad
            points.append((*point, targetFunc(*point)))

            # Hold still any coordinate whose update is below the threshold.
            delta = grad * step
            nextPoint = np.where(np.abs(delta) >= limitValue, point - delta, point)
            nextGrad = np.asarray(derivativeFunc(*nextPoint))
            diff = np.abs(grad - nextGrad)
            count += 1

        if (diff > limitValue).any():
            # The loop ended because of `timeout`, not because it converged.
            print("警告: 迭代 {0} 次后仍未收敛, 结果可能不是驻点".format(count))

        value = targetFunc(*nextPoint)
        # Record the returned point as well, so a plot of `points` ends where
        # the search actually finished.
        points.append((*nextPoint, value))
        print("最终运算次数为 : {0}".format(count))
        return value, nextPoint
    
    
    if __name__=="__main__":
        # Sample f on a grid for the surface plot.
        x, y = np.meshgrid(np.linspace(-2, 23, 100), np.linspace(-2, 23, 100))
        fxy = targetFunc(x, y)

        fig = plt.figure()
        # Axes3D(fig) no longer attaches itself to the figure in Matplotlib >= 3.6
        # (deprecated since 3.4), which yields an empty window; add_subplot with
        # projection="3d" works on old and new versions alike.
        ax = fig.add_subplot(111, projection="3d")

        ax.plot_surface(x, y, fxy)
        minValue, minPoint = linerFunc((20, 20), targetFunc, derivativeFunc, ax=ax)
        # Red dots: every point recorded by the gradient descent in pointList.
        ax.scatter(*(np.array(pointList).T), c='r', s=20, label="gradient descent path")
        print("该函数在({0},{1})处有驻点,值为{2}".format(minPoint[0], minPoint[1], minValue))
        # legend() needs at least one labelled artist, otherwise it only warns.
        ax.legend()
        plt.show()

  • 相关阅读:
    词法分析程序
    0909关于编译原理
    深度学习中图像检测的评价标准
    【 记忆网络 1 】 Memory Network
    ssm又乱码
    百度地图标注没了
    Fragment与Activity交互(使用Handler)
    在android里用ExpandableListView实现二层和三层列表
    java中outer的使用
    android中使用Http下载文件并保存到本地SD卡
  • 原文地址:https://www.cnblogs.com/dofstar/p/11462941.html
Copyright © 2020-2023  润新知