"""Plot Himmelblau's function and minimise it with PyTorch's Adam optimizer."""
import numpy as np
import torch


def himmelblau(t):
    """Evaluate Himmelblau's function (the function whose extremum is sought).

    Args:
        t: Indexable pair; ``t[0]`` is the x coordinate and ``t[1]`` is the
            y coordinate. Only arithmetic operators are applied to the two
            components, so plain numbers, NumPy arrays (e.g. a meshgrid)
            and torch tensors are all accepted.

    Returns:
        ``(x**2 + y - 11)**2 + (x + y**2 - 7)**2`` in whatever type the
        arithmetic on ``t[0]`` and ``t[1]`` produces.
    """
    x, y = t[0], t[1]
    return (x ** 2 + y - 11) ** 2 + (x + y ** 2 - 7) ** 2


def jeshy(t):
    """Return ``t * 3 + 10``.

    Scratch test function. Its only use in the original script was a
    commented-out ``f = jeshy(f)`` inside the optimisation loop.
    """
    return t * 3 + 10


def _plot_surface():
    """Show Himmelblau's function on the grid [-6, 6) x [-6, 6), step 0.1."""
    # Imported lazily so that himmelblau() and the optimisation loop can be
    # used without pulling in Matplotlib.
    import matplotlib.pyplot as plt
    # The two imports below come from the original import list and are not
    # referenced by name in this script; they are retained rather than dropped.
    from matplotlib.colors import LinearSegmentedColormap  # noqa: F401
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    xs = np.arange(-6, 6, 0.1)
    ys = np.arange(-6, 6, 0.1)
    grid_x, grid_y = np.meshgrid(xs, ys)
    grid_z = himmelblau([grid_x, grid_y])

    fig = plt.figure()
    # fig.gca(projection='3d') was deprecated in Matplotlib 3.4.
    ax = fig.add_subplot(projection='3d')
    ax.plot_surface(grid_x, grid_y, grid_z)
    ax.view_init(60, -30)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    # plt.show() alone displays the figure; the extra fig.show() that used to
    # precede it was redundant.
    plt.show()


def _minimize(steps=20001, log_every=1000):
    """Minimise ``himmelblau`` with Adam, starting from ``[0., 0.]``.

    Args:
        steps: Number of loop iterations. No update is applied on iteration
            0, so ``steps - 1`` optimizer updates are performed.
        log_every: Print a progress line whenever ``step % log_every == 0``.

    Returns:
        The optimised two-element tensor.
    """
    x = torch.tensor([0., 0.], requires_grad=True)
    # No lr is passed, so Adam runs with its default hyperparameters.
    optimizer = torch.optim.Adam([x])
    for step in range(steps):
        f = himmelblau(x)  # forward pass
        if step > 0:
            optimizer.zero_grad()  # clear gradients left over from the last step
            # The graph is rebuilt by the forward pass on every iteration and
            # backpropagated exactly once, so retain_graph is not needed.
            f.backward()
            optimizer.step()  # apply the update to x
        if step % log_every == 0:
            # NOTE: for step > 0, ``f`` was computed before the update while
            # ``x`` is printed after it. f.item() prints a plain float.
            print('step:{}, x = {}, value = {}'.format(step, x.tolist(), f.item()))
    return x


def main():
    """Show the surface plot, then run the optimisation."""
    _plot_surface()
    _minimize()


if __name__ == '__main__':
    main()
输出:
step:0, x = [0.0, 0.0], value = 170.0
step:1000, x = [1.270142912864685, 1.1183991432189941], value = 88.53223419189453
step:2000, x = [2.332378387451172, 1.9535712003707886], value = 13.766233444213867
step:3000, x = [2.8519949913024902, 2.114161968231201], value = 0.6711398363113403
step:4000, x = [2.981964111328125, 2.0271568298339844], value = 0.014927156269550323
step:5000, x = [2.9991261959075928, 2.0014777183532715], value = 3.9870232285466045e-05
step:6000, x = [2.999983549118042, 2.0000221729278564], value = 1.1074007488787174e-08
step:7000, x = [2.9999899864196777, 2.000013589859009], value = 4.150251697865315e-09
step:8000, x = [2.9999938011169434, 2.0000083446502686], value = 1.5572823031106964e-09
step:9000, x = [2.9999964237213135, 2.000005006790161], value = 5.256879376247525e-10
step:10000, x = [2.999997854232788, 2.000002861022949], value = 1.8189894035458565e-10
step:11000, x = [2.9999988079071045, 2.0000014305114746], value = 5.547917680814862e-11
step:12000, x = [2.9999992847442627, 2.0000009536743164], value = 1.6370904631912708e-11
step:13000, x = [2.999999523162842, 2.000000476837158], value = 5.6843418860808015e-12
step:14000, x = [2.999999761581421, 2.000000238418579], value = 1.8189894035458565e-12
step:15000, x = [3.0, 2.0], value = 0.0
step:16000, x = [3.0, 2.0], value = 0.0
step:17000, x = [3.0, 2.0], value = 0.0
step:18000, x = [3.0, 2.0], value = 0.0
step:19000, x = [3.0, 2.0], value = 0.0
step:20000, x = [3.0, 2.0], value = 0.0