• pytorch深度学习:非线性模型


    1. 这篇博客使用深度学习框架搭建了一个预测三次函数的模型

    2. 归一化很重要,一定要对输入数据做normalize,否则神经网络很难收敛、效果会很差

     1 import torch
     2 from torch import nn,optim
     3 import torch.nn.functional as F
     4 from matplotlib import pyplot as plt
     5 
     6 class unLinear(nn.Module):
     7     def __init__(self,input_feature,num_hidden,output_size):
     8         super(unLinear,self).__init__()
     9         self.hidden=nn.Linear(input_feature,num_hidden)#一个层就是一个函数
    10         self.out=nn.Linear(num_hidden,output_size)#可以把层理解成函数的右值引用
    11 
    12     def forward(self,x):
    13         # x=F.relu(self.hidden(x))
    14         # x = torch.sigmoid(self.hidden(x))
    15         x=torch.tanh(self.hidden(x))
    16         x=self.out(x)
    17         return x
    18 
    19     def train(self,inputs,target,criterion,optimizer,epoches):
    20         print(inputs.size())
    21         print(target.size())
    22         loss=0
    23         for epoch in range(epoches):
    24             output = model.forward(inputs)
    25             # if epoch%1000==0:
    26             #     plt.scatter(inputs.detach().numpy(), output.detach().numpy(), c='#00CED1', s=10, alpha=0.8, label="test")
    27             #     plt.show()
    28             loss = criterion(output, target)
    29             optimizer.zero_grad()
    30             loss.backward()
    31             optimizer.step()
    32         return self, loss
    33 
    34 model = unLinear(input_feature=1,num_hidden=20,output_size=1)
    35 x=torch.torch.arange(-2,2,0.1)
    36 y=x.pow(3)+0.1*torch.rand(x.size())
    37 # print(x)
    38 # print(y)
    39 plt.scatter(x.detach().numpy(), y.detach().numpy(), c='#00CED1', s=10, alpha=0.8, label="test")
    40 plt.show()
    41 
    42 inputs=torch.unsqueeze(x,dim=1)
    43 target=torch.unsqueeze(y,dim=1)
    44 criterion=nn.MSELoss()
    45 optimizer = optim.SGD(model.parameters(), lr=1e-2)
    46 
    47 new_model=model.train(inputs=inputs,target=target,criterion=criterion,optimizer=optimizer,epoches=10000)
    48 
    49 # plt.scatter(x.numpy(),y.numpy(),c='#00CED1',s=10,alpha=0.8,label="test")
    50 # plt.show()
    51 
    52 x_predict=torch.unsqueeze(torch.arange(-2,2,0.05),dim=1)
    53 y_predict=model.forward(x_predict)
    54 # y_predict=model.forward(inputs)
    55 # print(inputs.size())
    56 # print(x_predict.size())
    57 # print(y_predict.detach().numpy())
    58 x_predict=torch.squeeze(x_predict)
    59 y_predict=torch.squeeze(y_predict)
    60 x_predict=x_predict.detach().numpy()
    61 y_predict=y_predict.detach().numpy()
    62 # print(y_predict)
    63 plt.scatter(x_predict,y_predict,s=10,alpha=0.8,label="test")
    64 plt.show()
  • 相关阅读:
    第五节、矩阵分解之LU分解
    第四节、逆矩阵与转置矩阵
    第三节、矩阵乘法
    第二节、矩阵消元(高斯消元)
    重学线代——声明篇
    第一节、方程组的几何解释
    String类
    Mycat的安装及配置
    使用InfluxDB、cAdvisor、Grafana监控服务器性能
    Rancher的使用
  • 原文地址:https://www.cnblogs.com/St-Lovaer/p/13696295.html
Copyright © 2020-2023  润新知