• pytorch搭建网络,保存参数,恢复参数


    这是看过莫凡python的学习笔记。

    搭建网络,两种方式

    (1)建立Sequential对象

    import torch
    net = torch.nn.Sequential(
                torch.nn.Linear(2,10),
                torch.nn.ReLU(),
                torch.nn.Linear(10,2))

    输出网络结构

    Sequential(
      (0): Linear(in_features=2, out_features=10, bias=True)
      (1): ReLU()
      (2): Linear(in_features=10, out_features=2, bias=True)
    )

    (2)建立网络类,继承torch.nn.module

    class Net(torch.nn.Module):
        def __init__(self):
            super(Net,self).__init__()
            self.hidden = torch.nn.Linear(2,10)
            self.predict = torch.nn.Linear(10,2)
        def forward(self,x):
            x = F.relu(self.hidden(x))
            x = self.predict(x)
            return x

    输出和上面基本一样,略微不同

    Net(
      (hidden): Linear(in_features=2, out_features=10, bias=True)
      (predict): Linear(in_features=10, out_features=2, bias=True)
    )

    保存模型,两种方式

    (1)保存整个网络,及网络参数

    torch.save(net,'net.pkl')

    (2)只保存网络参数

    torch.save(net.state_dict(),'net_params.pkl')

    恢复模型,两种方式

    (1)加载整个网络,及参数

    net2 = torch.load('net.pkl')

    (2)加载参数,但需实现网络

    net3 = torch.nn.Sequential(
                torch.nn.Linear(2,10),
                torch.nn.ReLU(),
                torch.nn.Linear(10,2))
    net3.load_state_dict(torch.load('net_params.pkl'))
  • 相关阅读:
    mvn tomcat7:help的14个命令
    leetcode Next Permutation
    leetcode Permutation
    java HashMap
    单链表反转(递归和非递归) (Java)
    java数据类型
    Multiply Strings 大数相乘 java
    SQL中如何使用UPDATE语句进行联表更新(转)
    循环建立索引
    js 跨域访问
  • 原文地址:https://www.cnblogs.com/wzyuan/p/9458008.html
Copyright © 2020-2023  润新知