• 4.1 模型构造


    模型构造

    基于Block类的模型构造方法:它让模型构造更加灵活。

    继承Block类来构造模型

    Block类是nn模块里提供的一个模型构造类,我们可以继承它来定义我们想要的模型。下面继承Block类构造本节开头提到的多层感知机。这里定义的MLP类重载了Block类的__init__函数和forward函数。它们分别用于创建模型参数和定义前向计算。前向计算也即正向传播。

    #导包
    from mxnet import nd
    from mxnet.gluon import nn
    #MLP从Block继承
    class MLP(nn.Block):
        # 声明带有模型参数的层,这里声明了两个全连接层
        def __init__(self, **kwargs):
            # 调用MLP父类Block的构造函数来进行必要的初始化。这样在构造实例时还可以指定其他函数
            # 参数,如“模型参数的访问、初始化和共享”一节将介绍的模型参数params
            ##用super()来继承另一个类的参数,以dict的形式放在kwargs变量里面
            super(MLP, self).__init__(**kwargs)
            self.hidden = nn.Dense(256, activation='relu')  # 隐藏层
            self.output = nn.Dense(10)  # 输出层
    
        # 定义模型的前向计算,即如何根据输入x计算返回所需要的模型输出
        def forward(self, x):
            return self.output(self.hidden(x))
    
    #初始化net并传入输入数据X做一次前向计算
    X = nd.random.uniform(shape=(2, 20))
    net = MLP()
    net.initialize()
    net(X)
    

    net(X)会调用MLP继承自Block类的__call__函数,这个函数将调用MLP类定义的forward函数来完成前向计算。
    这里并没有将Block类命名为Layer(层)或者Model(模型)之类的名字,这是因为该类是一个可供自由组建的部件。它的子类既可以是一个层(如Gluon提供的Dense类),又可以是一个模型(如这里定义的MLP类),或者是模型的一个部分。

    Sequential类继承自Block

    Block类是一个通用的部件.
    Sequential类继承自Block类。当模型的前向计算为简单串联各个层的计算时,可以通过更加简单的方式定义模型。这正是Sequential类的目的:它提供add函数来逐一添加串联的Block子类实例,而模型的前向计算就是将这些实例按添加的顺序逐一计算。
    实现一个与Sequential类有相同功能的MySequential类:

    #MySequential类继承Block类
    class MySequential(nn.Block):
        #继承实现构造函数
        def __init__(self, **kwargs):
            super(MySequential, self).__init__(**kwargs)
        def add(self, block):
            # block是一个Block子类实例,假设它有一个独一无二的名字。我们将它保存在Block类的
            # 成员变量_children里,其类型是OrderedDict。当MySequential实例调用
            # initialize函数时,系统会自动对_children里所有成员初始化
            #self._children = OrderedDict(),有序字典可以按字典中元素的插入顺序来输出。
            self._children[block.name] = block
    
        def forward(self, x):
            # OrderedDict保证会按照成员添加时的顺序遍历成员
            for block in self._children.values():
                #直接利用__call__做前向传播
                #out = self.forward(*args)  
                x = block(x)
            return x
    

    MySequential类来实现前面描述的MLP类,并使用随机初始化的模型做一次前向计算。

    #先申明MySequential()实例
    net = MySequential()
    #添加隐藏层
    net.add(nn.Dense(256, activation='relu'))
    #添加输出层
    net.add(nn.Dense(10))
    #初始化
    net.initialize()
    net(X)
    

    构造复杂的模型

    构造一个稍微复杂点的网络FancyMLP。在这个网络中,我们通过get_constant函数创建训练中不被迭代的参数,即常数参数。在前向计算中,除了使用创建的常数参数外,我们还使用NDArray的函数和Python的控制流,并多次调用相同的层。

    class FancyMLP(nn.Block):
        def __init__(self, **kwargs):
            super(FancyMLP, self).__init__(**kwargs)
            # 使用get_constant创建的随机权重参数不会在训练中被迭代(即常数参数)
            self.rand_weight = self.params.get_constant(
                'rand_weight', nd.random.uniform(shape=(20, 20)))
            self.dense = nn.Dense(20, activation='relu')
    
        def forward(self, x):
            x = self.dense(x)
            # 使用创建的常数参数,以及NDArray的relu函数和dot函数
            x = nd.relu(nd.dot(x, self.rand_weight.data()) + 1)
            # 复用全连接层。等价于两个全连接层共享参数
            x = self.dense(x)
            # 控制流,这里我们需要调用asscalar函数来返回标量进行比较
            while x.norm().asscalar() > 1:
                x /= 2
            if x.norm().asscalar() < 0.8:
                x *= 10
            return x.sum()
    
    net = FancyMLP()
    net.initialize()
    net(X)
    
    class NestMLP(nn.Block):
        def __init__(self, **kwargs):
            super(NestMLP, self).__init__(**kwargs)
            #嵌套调用Sequential()类
            self.net = nn.Sequential()
            #添加两个隐藏层
            self.net.add(nn.Dense(64, activation='relu'),
                         nn.Dense(32, activation='relu'))
            #添加输出层
            self.dense = nn.Dense(16, activation='relu')
        #正向传播
        def forward(self, x):
            return self.dense(self.net(x))
    
    net = nn.Sequential()
    net.add(NestMLP(), nn.Dense(20), FancyMLP())
    
    net.initialize()
    net(X)
    
  • 相关阅读:
    2019 学霸君java面试笔试题 (含面试题解析)
    2019 大众书网Java面试笔试题 (含面试题解析)
    2019 中细软java面试笔试题 (含面试题解析)
    2019 企叮咚java面试笔试题 (含面试题解析)
    js 去掉数组对象中的重复对象
    canvas霓虹雨
    nvm的安装
    socket.io 中文文档
    Nginx(三)------nginx 反向代理
    github入门到上传本地项目
  • 原文地址:https://www.cnblogs.com/strategist-614/p/14408865.html
Copyright © 2020-2023  润新知