• 3. 神经网络的保存，以及提取神经网络的两种方式


     1 """
     2 torch: 0.4
     3 matplotlib
     4 神经网络的保存 
     5 神经网络提取的2 ways
     6 """
     7 import torch
     8 import matplotlib.pyplot as plt
     9 
    10 
    11 
    # fake data: a noisy parabola. Autograd works directly on tensors since
    # PyTorch 0.4, so the old Variable(...) wrapping is no longer needed.
    x = torch.linspace(-1, 1, 100).unsqueeze(dim=1)  # 100 points in [-1, 1], shape (100, 1)
    y = x.pow(2) + torch.rand(x.size()) * 0.2        # x^2 plus uniform noise in [0, 0.2), shape (100, 1)
    18 
    # save net1
    def save(net_path='net.pkl', params_path='net_params.pkl'):
        """Train net1 on the module-level ``x``/``y`` data, plot its fit, and save it.

        net1 is a 1 -> 10 -> 1 network with a ReLU hidden layer, trained with
        full-batch SGD (lr=0.5, 100 steps) on an MSE loss. Its fit is drawn in
        the first of three subplots of figure 1, then the net is saved twice.

        Args:
            net_path: file that receives the entire pickled module
                (``torch.save(net1, ...)``). Defaults to ``'net.pkl'``.
            params_path: file that receives only the parameters
                (``net1.state_dict()``). Defaults to ``'net_params.pkl'``.
        """
        # Build the network instance net1.
        net1 = torch.nn.Sequential(
            torch.nn.Linear(1, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 1),
        )
        optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
        loss_func = torch.nn.MSELoss()

        # Training loop: every step uses the whole data set.
        for _ in range(100):
            loss = loss_func(net1(x), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Predict once training has finished, so the plotted curve reflects
        # exactly the weights saved below (a prediction taken inside the loop
        # would predate the last optimizer.step()). No autograd graph needed.
        with torch.no_grad():
            prediction = net1(x)

        # plot result
        plt.figure(1, figsize=(10, 3))
        plt.subplot(131)
        plt.title('Net1')
        plt.scatter(x.numpy(), y.numpy())
        plt.plot(x.numpy(), prediction.numpy(), 'r-', lw=5)

        # 2 ways to save the net
        torch.save(net1, net_path)                  # entire net (pickled module)
        torch.save(net1.state_dict(), params_path)  # parameters only
    47 
    48 
    def restore_net(net_path='net.pkl'):
        """Load an entire pickled network (net2) and plot its fit on ``x``/``y``.

        The fit is drawn in the second of three subplots of the current figure.

        Args:
            net_path: file written by ``torch.save(net, ...)`` with the whole
                module. Defaults to ``'net.pkl'``.

        Note:
            Loading a whole module unpickles arbitrary Python objects, which
            can execute code. Only load files you created yourself or trust.
        """
        # restore entire net1 to net2
        try:
            # PyTorch >= 2.6 defaults to weights_only=True, which refuses to
            # unpickle a full nn.Module, so opt out explicitly.
            net2 = torch.load(net_path, weights_only=False)
        except TypeError:
            # Older PyTorch (e.g. the 0.4 this script targets) does not accept
            # a weights_only argument.
            net2 = torch.load(net_path)

        # Inference only: no autograd graph, so the result converts to NumPy.
        with torch.no_grad():
            prediction = net2(x)

        # plot result
        plt.subplot(132)
        plt.title('Net2')
        plt.scatter(x.numpy(), y.numpy())
        plt.plot(x.numpy(), prediction.numpy(), 'r-', lw=5)
    59 
    60 
    def restore_params(params_path='net_params.pkl'):
        """Rebuild the network, load saved parameters into it (net3), plot, and show.

        A ``state_dict`` holds only tensors, so the architecture is rebuilt here.
        It must match the one the parameters were saved from (1 -> 10 -> 1 with
        a ReLU hidden layer). The fit is drawn in the third of three subplots,
        and then the figure is displayed.

        Args:
            params_path: file written by ``torch.save(net.state_dict(), ...)``.
                Defaults to ``'net_params.pkl'``.
        """
        # Rebuild the same architecture as net1 to receive its parameters.
        net3 = torch.nn.Sequential(
            torch.nn.Linear(1, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 1),
        )

        # copy net1's parameters into net3
        net3.load_state_dict(torch.load(params_path))

        # Inference only: no autograd graph, so the result converts to NumPy.
        with torch.no_grad():
            prediction = net3(x)

        # plot result
        plt.subplot(133)
        plt.title('Net3')
        plt.scatter(x.numpy(), y.numpy())
        plt.plot(x.numpy(), prediction.numpy(), 'r-', lw=5)
        # Display the current figure, including subplots drawn by earlier calls.
        plt.show()
    79 
    # Train net1 and write it to net.pkl (whole module) and net_params.pkl
    # (state_dict only).
    save()
    # Reload the entire pickled net; this may be slower than loading params.
    restore_net()
    # Rebuild the architecture and load only the saved parameters; this call
    # also shows the figure.
    restore_params()
  • 相关阅读:
    MySQL: Connection Refused,调整 mysql.ini中的 max_connections
    Eclipse: Difference between clean, build and publish
    Enterprise Integration Patterns
    圆上两点的解题思路(用户需求分析的隐喻)
    Activiti解析.bpmn文件获得User Task节点的CandidateUsers特性的值
    Activiti的BPMN演示工具
    Activiti For Eclipse(Mars)插件配置
    TortoiseSvn/Git的WaterEffect
    Activiti启动某个流程失败,页面报500
    eclipse webproject activiti
  • 原文地址:https://www.cnblogs.com/xuechengmeigui/p/12388514.html
Copyright © 2020-2023  润新知