求参数w的梯度有两种方式。方式1.
# Method 1: backpropagate from the scalar loss `mse`; autograd accumulates
# d(mse)/dw into w.grad, which is then inspected.
mse.backward()
w.grad
方式2.
# Method 2: torch.autograd.grad returns d(mse)/dw as a tuple (one entry per
# tensor in the input list) instead of writing it into w.grad.
# NOTE(review): if mse.backward() already ran on this same graph, it must have
# been called with retain_graph=True, otherwise this call fails because the
# graph has been freed.
torch.autograd.grad(mse,[w])
# Gradient of a loss function, computed two ways.
import torch
import torch.nn.functional as F

x = torch.ones(1)
# The fill value must be a float: torch.full infers int64 from an int fill
# value, and integer tensors cannot require gradients.
w = torch.full([1], 2.)
# Autograd only records operations performed after requires_grad_(), so the
# loss is built once w has been marked as needing a gradient.
w.requires_grad_()
mse = F.mse_loss(torch.ones(1), x * w)

# Method 1: backward() accumulates d(mse)/dw into w.grad.
# retain_graph=True keeps the graph alive so method 2 can reuse it below;
# without it the second differentiation raises a RuntimeError.
mse.backward(retain_graph=True)
grad_via_backward = w.grad
print(grad_via_backward)

# Method 2: torch.autograd.grad returns the gradients as a tuple instead of
# writing them into w.grad.
grad_via_autograd = torch.autograd.grad(mse, [w])
print(grad_via_autograd)
# Gradient of the softmax function.
import torch
import torch.nn.functional as F

a = torch.rand(3)
# Enable gradient tracking before building the graph; a softmax computed
# earlier would not be differentiable with respect to a.
a.requires_grad_()
p = F.softmax(a, dim=0)
# d p[2] / d a: one row of the softmax Jacobian, p[2] * (delta_2j - p[j]).
# retain_graph=True keeps the graph so further rows (p[0], p[1]) can be queried.
(dp2_da,) = torch.autograd.grad(p[2], [a], retain_graph=True)
print(dp2_da)
# Gradient of a single-layer, two-output perceptron with a sigmoid activation.
import torch
import torch.nn.functional as F

x = torch.randn(1, 10)                      # input, shape (1, 10)
w = torch.randn(2, 10, requires_grad=True)  # weights, shape (2, 10)
o = torch.sigmoid(x @ w.t())                # output, shape (1, 2)
# Outside an interactive session a bare `o.shape` expression shows nothing.
print(o.shape)
loss = F.mse_loss(torch.ones(1, 2), o)
loss.backward()  # populates w.grad, shape (2, 10)
print(w.grad)
# Gradient of a two-layer function: check the chain rule numerically.
import torch
import torch.nn.functional as F

x = torch.tensor(1.)
w1 = torch.tensor(2., requires_grad=True)
b1 = torch.tensor(1.)
w2 = torch.tensor(2., requires_grad=True)
b2 = torch.tensor(1.)
y1 = x * w1 + b1
y2 = y1 * w2 + b2
# The names read "d y2 / d y1", "d y1 / d w1" and "d y2 / d w1". The same graph
# is differentiated three times, hence retain_graph=True on each call.
day2_dy1 = torch.autograd.grad(y2, [y1], retain_graph=True)[0]  # = w2
day1_dw1 = torch.autograd.grad(y1, [w1], retain_graph=True)[0]  # = x
day2_dw1 = torch.autograd.grad(y2, [w1], retain_graph=True)[0]  # = w2 * x
# Chain rule: dy2/dw1 == dy2/dy1 * dy1/dw1 (both print tensor(2.) here).
print(day2_dy1 * day1_dw1)
print(day2_dw1)
# Minimise the Himmelblau function with Adam (and plot its surface).
import numpy as np
import torch
import torch.nn.functional as F


def himmelblau(x):
    """Return the Himmelblau function (x0**2 + x1 - 11)**2 + (x0 + x1**2 - 7)**2.

    ``x`` only needs to support ``x[0]`` and ``x[1]``. The script below passes
    a list of two NumPy grids (elementwise evaluation for the surface plot)
    and a length-2 tensor (scalar value for the optimiser).
    """
    return (x[0] ** 2 + x[1] - 11) ** 2 + (x[0] + x[1] ** 2 - 7) ** 2


def minimize_himmelblau(start=(0., 0.), lr=1e-3, steps=20000, log_every=2000):
    """Minimise ``himmelblau`` with Adam, starting from ``start``.

    Prints progress every ``log_every`` steps (0 disables logging). Returns the
    final point as a length-2 float32 tensor with ``requires_grad=True``.
    """
    # dtype is forced to float so integer starting points can still require grad.
    x = torch.tensor(start, dtype=torch.float32, requires_grad=True)
    optimizer = torch.optim.Adam([x], lr=lr)  # lr=1e-3 means a learning rate of 0.001
    for step in range(steps):
        pred = himmelblau(x)
        optimizer.zero_grad()
        pred.backward()
        optimizer.step()
        if log_every and step % log_every == 0:
            # pred is f(x) before this step's update; x is already updated.
            print('step {}:x={},f(x)={}'.format(step, x.tolist(), pred.item()))
    return x


if __name__ == "__main__":
    # Plotting-only imports live here so the functions above are usable
    # without Matplotlib installed.
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers '3d' on older Matplotlib

    x = np.arange(-6, 6, 0.1)
    y = np.arange(-6, 6, 0.1)
    print('x,y range', x.shape, y.shape)
    X, Y = np.meshgrid(x, y)
    print('X,Y maps:', X.shape, Y.shape)
    Z = himmelblau([X, Y])
    fig = plt.figure('himmelblau')
    # Figure.gca() no longer accepts keyword arguments (Matplotlib >= 3.6);
    # add_subplot(projection='3d') is the supported way to get 3-D axes.
    ax = fig.add_subplot(projection='3d')
    ax.plot_surface(X, Y, Z)
    ax.view_init(60, -30)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    plt.show()

    x = minimize_himmelblau()
step 0:x=[0.0009999999310821295, 0.0009999999310821295],f(x)=170.0 step 2000:x=[2.3331806659698486, 1.9540694952011108],f(x)=13.730916023254395 step 4000:x=[2.9820079803466797, 2.0270984172821045],f(x)=0.014858869835734367 step 6000:x=[2.999983549118042, 2.0000221729278564],f(x)=1.1074007488787174e-08 step 8000:x=[2.9999938011169434, 2.0000083446502686],f(x)=1.5572823031106964e-09 step 10000:x=[2.999997854232788, 2.000002861022949],f(x)=1.8189894035458565e-10 step 12000:x=[2.9999992847442627, 2.0000009536743164],f(x)=1.6370904631912708e-11 step 14000:x=[2.999999761581421, 2.000000238418579],f(x)=1.8189894035458565e-12 step 16000:x=[3.0, 2.0],f(x)=0.0 step 18000:x=[3.0, 2.0],f(x)=0.0