• 深度学习之构造模型,访问模型参数——2020.3.11


    今天主要学习了利用torch中的nn模块定义Module类,下面的代码包含对于模型类的构建以及参数访问,简便的可以使用‘net = nn.Sequential(NestMLP(), nn.Linear(30, 20), FancyMLP())’构建模型,默认进行初始化。

    # 3.1 构造模型
    import torch 
    from torch import nn #module类是nn模块里提供的一个模型构造类
    
    # 定义MLP类
    class MLP(nn.Module):
        def __init__(self, **kwargs):
            super(MLP, self).__init__(**kwargs) #重载MLP类
            self.hidden = nn.Linear(784, 256)
            self.act = nn.ReLU()
            self.output = nn.Linear(256, 10)
        
        # 定义前向计算,反向传播函数可通过生成反向传播所需的backward函数
        def forward(self, x):
            a = self.act(self.hidden(x))
            return self.output(a)
    
    # 初始化net并传入输入数据x,做前向计算
    X = torch.rand(2, 784)
    net = MLP()
    net(X)
    

    # 4.12 module 的子类
    class MySquential(nn.Module):
        from collections import OrderedDict
        def __init__(self, *args):
            super(MySquential, self).__init__()
            if len(args) == 1 and isinstance(args[0], OrderedDict):
                for key, module in args[0].items():
                    self.add_module(key,module)
            else:
                for idx, module in enumerate(args):
                    self.add_module(str(idx), module)
        
        def forward(self, input):
            for module in self._modules.values():
                input = module(input)
            return input
    
    net = MySquential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10),)
    print(net)
    net(X)
    

    输出结果

    # ModuleLise 类
    net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
    net.append(nn.Linear(256, 10))
    print(net[0]) #使用Listd的索引访问
    print(net)
    

    输出结果

    # ModuleDict类
    net = nn.ModuleDict({'linear' : nn.Linear(784, 256), 'act' : nn.ReLU(),})
    net['output'] = nn.Linear(256, 10)
    print(net['linear']) # 访问
    print(net.output)
    print(net)
    
    # 构造复杂模型
    class FancyMLP(nn.Module):
        def __init__(self, **kwargs):
            super(FancyMLP, self).__init__(**kwargs)
            
            self.rand_weight = torch.rand((20, 20),requires_grad=False)
            self.linear = nn.Linear(20, 20)
            
        def forward(self, x):
            x = self.linear(x)
            x = nn.functional.relu(torch.mm(x, self.rand_weight.data) + 1)
            
            x = self.linear(x)
            while x.norm().item() > 1:
                x /= 2
            if x.norm().item() < 0.0:
                x *= 10
            return x.sum()
            
    
    X = torch.rand(2, 20)
    net = FancyMLP()
    print(net)
    net(X)
    

    输出结果

    # 嵌套调用FancyMLP和Sequential类
    class NestMLP(nn.Module):
        def __init__(self, **kwargs):
            super(NestMLP, self).__init__(**kwargs)
            self.net = nn.Sequential(nn.Linear(40,30), nn.ReLU())
        
        def forward(self, x):
            return self.net(x)
    
    net = nn.Sequential(NestMLP(), nn.Linear(30, 20), FancyMLP())
    
    X = torch.rand(2, 40)
    print(net)
    net(X)
    

    输出结果

    # 4.2 模型参数的访问、初始化和共享
    import torch
    from torch import nn
    from torch.nn import init
    
    net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1)) 
    
    print(net)
    X = torch.rand(2, 4)
    Y = net(X).sum()
    

    输出结果

    # 访问模型参数
    print(type(net.named_parameters()))
    for name, param in net.named_parameters():
        print(name, param.size())
    

    输出结果

  • 相关阅读:
    HDU 3416 Marriage Match IV(SPFA+最大流)
    asp.net一些很酷很实用的.Net技巧
    asp.net部分控件使用和开发技巧总结
    ASP_NET Global_asax详解
    asp.net 多字段模糊查询代码
    Asp.net中防止用户多次登录的方法
    SQL Server 事务、异常和游标
    有关Cookie
    asp.net 连接sql server 2005 用户 'sa' 登录失败。asp.net开发第一步连接的细节问题
    asp.net生成高质量缩略图通用函数(c#代码),支持多种生成方式
  • 原文地址:https://www.cnblogs.com/somedayLi/p/12461476.html
Copyright © 2020-2023  润新知