import numpy as np

try:
    # Plotting is only needed for the demo under ``__main__``; keep the
    # optimiser importable when matplotlib is not installed.
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d.axes3d import Axes3D
except ImportError:
    plt = None
    Axes3D = None


# Objective: f(x, y) = 2x^2 + 6y^2 + 6xy + x + 4y + 8
def targetFunc(x, y):
    """Return f(x, y) = 2x^2 + 6y^2 + 6xy + x + 4y + 8.

    Uses only arithmetic operators, so it works element-wise on numpy
    arrays (e.g. a meshgrid) as well as on scalars.
    """
    return 2 * (x ** 2) + 6 * y ** 2 + 6 * x * y + x + 4 * y + 8


# Partial derivatives:
#   df/dx = 4x + 6y + 1
#   df/dy = 12y + 6x + 4
def derivativeFunc(x, y):
    """Return the gradient ``(df/dx, df/dy)`` of ``targetFunc`` at (x, y)."""
    rx = 4 * x + 6 * y + 1
    ry = 12 * y + 6 * x + 4
    return (rx, ry)


# Default sink for the visited points ``(x, y, f(x, y))`` recorded by
# ``linerFunc`` when no explicit ``trace`` list is supplied.
pointList = []


def linerFunc(initPoint: tuple, targetFunc, derivativeFunc, step=0.01,
              limitValue=0.00000001, timeout=1000000, ax: Axes3D = None,
              trace: "list | None" = None):
    """Minimise a 2-D function by fixed-step gradient descent.

    Starting at ``initPoint``, repeatedly moves to
    ``point - step * gradient(point)``. It stops when the change in the
    gradient between two successive iterates is at most ``limitValue`` in
    every component, or when ``timeout`` iterations have been performed.
    It prints the iteration count.

    Args:
        initPoint: starting (x, y) coordinates.
        targetFunc: callable ``f(x, y)`` returning the objective value.
        derivativeFunc: callable returning the gradient ``(df/dx, df/dy)``.
        step: learning rate applied to the gradient.
        limitValue: convergence tolerance. A coordinate whose update
            ``|step * gradient|`` is below it is left unchanged.
        timeout: maximum number of iterations.
        ax: not used by this function; accepted for backward compatibility.
        trace: list that receives every visited point as an
            ``(x, y, f(x, y))`` tuple, including the final one. Defaults to
            the module-level ``pointList``.

    Returns:
        ``(value, point)``: the objective value at the final iterate and the
        final iterate as a numpy array.
    """
    if trace is None:
        trace = pointList
    count = 1
    initPoint = np.array(initPoint)
    ro, do = targetFunc(*initPoint), np.array(derivativeFunc(*initPoint))
    trace.append((*initPoint, ro))
    newPoint = initPoint - do * step
    rn, dn = targetFunc(*newPoint), np.array(derivativeFunc(*newPoint))
    diff = np.abs(do - dn)
    while (diff > limitValue).any() and count < timeout:
        initPoint = newPoint
        ro, do = targetFunc(*initPoint), np.array(derivativeFunc(*initPoint))
        # Freeze any coordinate whose update would be smaller than the
        # tolerance; if every coordinate freezes the gradient stops changing
        # and the loop exits.
        newPoint = np.where(np.abs(do * step) >= limitValue,
                            initPoint - do * step, initPoint)
        rn, dn = targetFunc(*newPoint), np.array(derivativeFunc(*newPoint))
        diff = np.abs(do - dn)
        trace.append((*initPoint, ro))
        count += 1
    # Record the converged point too, so the returned point is on the path.
    trace.append((*newPoint, rn))
    print("最终运算次数为 : {0}".format(count))
    return rn, newPoint


if __name__ == "__main__":
    if plt is None:
        raise SystemExit("matplotlib is required to run this demo")
    x, y = np.linspace(-2, 23, 100), np.linspace(-2, 23, 100)
    x, y = np.meshgrid(x, y)
    fxy = targetFunc(x, y)
    fig = plt.figure()
    # ``Axes3D(fig)`` no longer adds itself to the figure (matplotlib >= 3.6),
    # which leaves an empty window; request a 3-D subplot instead.
    ax = fig.add_subplot(111, projection="3d")
    ax.plot_surface(x, y, fxy)
    limitValue, limitPoint = linerFunc((20, 20), targetFunc, derivativeFunc, ax=ax)
    # Label the path so ``legend()`` has an entry to show.
    ax.scatter(*(np.array(pointList).T), c='r', s=20, label="descent path")
    print("该函数在({0},{1})处有驻点,值为{2}".format(limitPoint[0], limitPoint[1], limitValue))
    ax.legend()
    plt.show()