1. 这篇博客使用深度学习框架 PyTorch 搭建了一个拟合三次函数 y = x³ 的神经网络模型。
2. 归一化(normalization,不是正则化 regularization)很重要，一定要先对数据做 normalize，否则神经网络的训练效果会非常差。
import torch
from torch import nn, optim
import torch.nn.functional as F  # only needed for the commented-out ReLU alternative in forward()


class unLinear(nn.Module):
    """One-hidden-layer MLP used to fit a non-linear (cubic) curve.

    Architecture: Linear(input_feature, num_hidden) -> tanh -> Linear(num_hidden, output_size).
    """

    def __init__(self, input_feature, num_hidden, output_size):
        super().__init__()
        self.hidden = nn.Linear(input_feature, num_hidden)
        self.out = nn.Linear(num_hidden, output_size)

    def forward(self, x):
        """Map a batch of shape (N, input_feature) to (N, output_size)."""
        # tanh is used for the hidden activation; F.relu / torch.sigmoid were
        # the other options the author tried.
        x = torch.tanh(self.hidden(x))
        return self.out(x)

    def fit(self, inputs, target, criterion, optimizer, epoches):
        """Run full-batch gradient descent on (inputs, target).

        Args:
            inputs: input batch passed to the model each step.
            target: tensor compared against the model output by ``criterion``.
            criterion: loss callable, e.g. ``nn.MSELoss()``.
            optimizer: optimizer built over ``self.parameters()``.
            epoches: number of optimization steps.

        Returns:
            ``(self, loss)`` where ``loss`` is the loss tensor of the last
            step, or ``0`` when ``epoches`` is 0.
        """
        # Make sure we are in training mode even if eval() was called earlier.
        super().train(True)
        loss = 0
        for _ in range(epoches):
            # Use self (not a module-level global) and __call__ so hooks run.
            output = self(inputs)
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        return self, loss

    def train(self, *args, **kwargs):
        """Backward-compatible entry point that also keeps nn.Module semantics.

        Overriding ``nn.Module.train`` with the fitting loop broke
        ``model.eval()``, which internally calls ``self.train(False)``.
        Calls shaped like ``nn.Module.train`` -- no arguments, a single bool,
        or ``mode=<bool>`` -- toggle training mode and return ``self``.
        Any other call (the legacy ``train(inputs=..., target=..., ...)``
        form) is forwarded to :meth:`fit` and returns ``(self, loss)``.
        """
        mode_only = (
            len(args) + len(kwargs) <= 1
            and set(kwargs) <= {"mode"}
            and all(isinstance(v, bool) for v in (*args, *kwargs.values()))
        )
        if mode_only:
            return super().train(*args, **kwargs)
        return self.fit(*args, **kwargs)


def main():
    """Fit unLinear to noisy samples of y = x**3 on [-2, 2) and plot the fit."""
    # Imported lazily so the module can be imported without matplotlib.
    from matplotlib import pyplot as plt

    model = unLinear(input_feature=1, num_hidden=20, output_size=1)
    x = torch.arange(-2, 2, 0.1)
    # NOTE(review): torch.rand samples U[0, 1), so this noise is non-negative
    # (mean 0.05); torch.randn would give zero-mean noise -- confirm intent.
    y = x.pow(3) + 0.1 * torch.rand(x.size())
    plt.scatter(x.numpy(), y.numpy(), c='#00CED1', s=10, alpha=0.8, label="test")
    plt.show()

    # nn.Linear expects (N, features): add a feature dimension of size 1.
    inputs = torch.unsqueeze(x, dim=1)
    target = torch.unsqueeze(y, dim=1)
    print(inputs.size())
    print(target.size())
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-2)
    model.fit(inputs=inputs, target=target, criterion=criterion,
              optimizer=optimizer, epoches=10000)

    # Inference: switch to eval mode and skip building the autograd graph.
    model.eval()
    with torch.no_grad():
        x_predict = torch.unsqueeze(torch.arange(-2, 2, 0.05), dim=1)
        y_predict = model(x_predict)
    plt.scatter(torch.squeeze(x_predict).numpy(), torch.squeeze(y_predict).numpy(),
                s=10, alpha=0.8, label="test")
    plt.show()


if __name__ == "__main__":
    main()