• 【501】pytorch教程之nn.Module类详解


    参考:pytorch教程之nn.Module类详解——使用Module类来自定义模型

      pytorch中对于一般的序列模型,直接使用torch.nn.Sequential类及可以实现,这点类似于keras,但是更多的时候面对复杂的模型,比如:多输入多输出、多分支模型、跨层连接模型、带有自定义层的模型等,就需要自己来定义一个模型了。本文将详细说明如何让使用Mudule类来自定义一个模型。

      pytorch里面一切自定义操作基本上都是继承nn.Module类来实现的。

      我们在定义自已的网络的时候,需要继承nn.Module类,并重新实现构造函数__init__构造函数和forward这两个方法。但有一些注意技巧:

    • 一般把网络中具有可学习参数的层(如全连接层、卷积层等)放在构造函数__init__()中,当然我也可以吧不具有参数的层也放在里面;
    • 一般把不具有可学习参数的层(如ReLU、dropout、BatchNormanation层)可放在构造函数中,也可不放在构造函数中,如果不放在构造函数__init__里面,则在forward方法里面可以使用nn.functional来代替
    • forward方法是必须要重写的,它是实现模型的功能,实现各个层之间的连接关系的核心。

      所有放在构造函数__init__里面的层的都是这个模型的“固有属性”。

      官方例子

    import torch.nn as nn
    import torch.nn.functional as F
    
    class Model(nn.Module):
        def __init__(self):
            # 固定内容
            super(Model, self).__init__()
    
            # 定义相关的函数
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)
    
        def forward(self, x):
            # 构建模型结构,可以使用F函数内容,其他调用__init__里面的函数
            x = F.relu(self.conv1(x))
    
            # 返回最终的结果
            return F.relu(self.conv2(x))
    

    ☀☀☀<< 举例 >>☀☀☀

      代码一:

    import torch
    
    N, D_in, H, D_out = 64, 1000, 100, 10
     
    torch.manual_seed(1)
    x = torch.randn(N, D_in)
    y = torch.randn(N, D_out)
     
    #-----changed part-----#
    model = torch.nn.Sequential(
        torch.nn.Linear(D_in, H),
        torch.nn.ReLU(),
        torch.nn.Linear(H, D_out),
    )
    #-----changed part-----#
    
    loss_fn = torch.nn.MSELoss(reduction='sum')
    learning_rate = 1e-4
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    for t in range(500):
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        if t % 100 == 99:
            print(t, loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    

      代码二:

    import torch
    
    N, D_in, H, D_out = 64, 1000, 100, 10
    
    torch.manual_seed(1)
    x = torch.randn(N, D_in)
    y = torch.randn(N, D_out)
    
    #-----changed part-----#
    class Alex_nn(nn.Module):
        def __init__(self):
            super(Alex_nn, self).__init__()
            self.h1 = torch.nn.Linear(D_in, H)
            self.h1_relu = torch.nn.ReLU()
            self.output = torch.nn.Linear(H, D_out)
            
        def forward(self, x):
            h1 = self.h1(x)
            h1_relu = self.h1_relu(h1)
            output = self.output(h1_relu)
            return output
            
    model = Alex_nn()
    #-----changed part-----#
    
    loss_fn = torch.nn.MSELoss(reduction='sum')
    learning_rate = 1e-4
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    for t in range(500):
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        if t % 100 == 99:
            print(t, loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    

      代码三:

    import torch
    
    N, D_in, H, D_out = 64, 1000, 100, 10
    
    torch.manual_seed(1)
    x = torch.randn(N, D_in)
    y = torch.randn(N, D_out)
    
    #-----changed part-----#
    class Alex_nn(nn.Module):
        def __init__(self, D_in_, H_, D_out_):
            super(Alex_nn, self).__init__()
            self.D_in = D_in_
            self.H = H_
            self.D_out = D_out_
            
            self.h1 = torch.nn.Linear(self.D_in, self.H)
            self.h1_relu = torch.nn.ReLU()
            self.output = torch.nn.Linear(self.H, self.D_out)
            
        def forward(self, x):
            h1 = self.h1(x)
            h1_relu = self.h1_relu(h1)
            output = self.output(h1_relu)
            return output
            
    model = Alex_nn(D_in, H, D_out)
    #-----changed part-----#
    
    loss_fn = torch.nn.MSELoss(reduction='sum')
    learning_rate = 1e-4
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    for t in range(500):
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        if t % 100 == 99:
            print(t, loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    

      

  • 相关阅读:
    Android中的sp与wp
    MTK
    linux kernel文件系统启动部分
    Java项目构建基础之统一结果
    线程和线程池的学习
    SpringBoot 中MyBatis的配置
    MyBatis中使用Map传参——返回值也是Map
    OAuth2的学习
    Java 跨域问题
    Spring Cloud 中的 eureka.instance.appname和spring.application.name 意思
  • 原文地址:https://www.cnblogs.com/alex-bn-lee/p/14092666.html
Copyright © 2020-2023  润新知