• 从零开始学习MXnet(三)之Model和Module


      在我们在MXnet中定义好symbol、写好dataiter并且准备好data之后,就可以开开心的去训练了。一般训练一个网络有两种常用的策略,基于model的和基于module的。今天,我想谈一谈他们的使用。

    一、Model

      按照老规矩,直接从官方文档里面拿出来的代码看一下:

      

     # configure a two layer neuralnetwork
        data = mx.symbol.Variable('data')
        fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
        act1 = mx.symbol.Activation(fc1, name='relu1', act_type='relu')
        fc2 = mx.symbol.FullyConnected(act1, name='fc2', num_hidden=64)
        softmax = mx.symbol.SoftmaxOutput(fc2, name='sm')
    # create a model using sklearn-style two-step way
    #创建一个model 
       model = mx.model.FeedForward(
             softmax,
             num_epoch=num_epoch,
             learning_rate=0.01)
    #开始训练
        model.fit(X=data_set)
    

      具体的API参照http://mxnet.io/api/python/model.html。

      然后呢,model这部分就说完了。。。之所以这么快主要有两个原因:

        1.确实东西不多,一般都是查一查文档就可以了。

        2.model的可定制性不强,一般我们是很少使用的,常用的还是module。

    二、Module

      Module真的是一个很棒的东西,虽然深入了解后,你会觉得“哇,好厉害,但是感觉没什么鸟用呢”这种想法。。实际上我就有过,现在回想起来,从代码的设计和使用的角度来讲,Module确实是一个非常好的东西,它可以为我们的网络计算提高了中级、高级的接口,这样一来,就可以有很多的个性化配置让我们自己来做了。

      Module有四种状态:

        1.初始化状态,就是显存还没有被分配,基本上啥都没做的状态。

        2.binded,在把data和label的shape传到Bind函数里并且执行之后,显存就分配好了,可以准备好计算能力。

        3.参数初始化。就是初始化参数

        3.Optimizer installed 。就是传入SGD,Adam这种optimuzer中去进行训练 

     先上一个简单的代码:

      

    import mxnet as mx
    
        # construct a simple MLP
        data = mx.symbol.Variable('data')
        fc1  = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
        act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
        fc2  = mx.symbol.FullyConnected(act1, name = 'fc2', num_hidden = 64)
        act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu")
        fc3  = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10)
        out  = mx.symbol.SoftmaxOutput(fc3, name = 'softmax')
    
        # construct the module
        mod = mx.mod.Module(out)
       
         mod.bind(data_shapes=train_dataiter.provide_data,
             label_shapes=train_dataiter.provide_label)
       
         mod.init_params()
         mod.fit(train_dataiter, eval_data=eval_dataiter,
                optimizer_params={'learning_rate':0.01, 'momentum': 0.9},
                num_epoch=n_epoch)
    

      分析一下:首先是定义了一个简单的MLP,symbol的名字就叫做out,然后可以直接用mx.mod.Module来创建一个mod。之后mod.bind的操作是在显卡上分配所需的显存,所以我们需要把data_shapehe label_shape传递给他,然后初始化网络的参数,再然后就是mod.fit开始训练了。这里补充一下。fit这个函数我们已经看见两次了,实际上它是一个集成的功能,mod.fit()实际上它内部的核心代码是这样的:

      

    for epoch in range(begin_epoch, num_epoch): 
                 tic = time.time() 
                 eval_metric.reset() 
                 for nbatch, data_batch in enumerate(train_data): 
                     if monitor is not None: 
                         monitor.tic() 
                     self.forward_backward(data_batch) #网络进行一次前向传播和后向传播
                     self.update()  #更新参数
                     self.update_metric(eval_metric, data_batch.label) #更新metric 
     
     
                     if monitor is not None: 
                         monitor.toc_print() 
     
    
                     if batch_end_callback is not None: 
                         batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch, 
                                                          eval_metric=eval_metric, 
                                                          locals=locals()) 
                         for callback in _as_list(batch_end_callback): 
                             callback(batch_end_params) 
    

      正是因为module里面我们可以使用很多intermediate的interface,所以可以做出很多改进,举个最简单的例子:如果我们的训练网络是大小可变怎么办? 我们可以实现一个mutumodule,基本上就是,每次data的shape变了的时候,我们就重新bind一下symbol,这样训练就可以照常进行了。

      

      总结:实际上学一个框架的关键还是使用它,要说诀窍的话也就是多看看源码和文档了,我写这些博客的目的,一是为了记录一些东西,二是让后来者少走一些弯路。所以有些东西不会说的很全。。

      

  • 相关阅读:
    HDU-1102 Constructing Roads ( 最小生成树 )
    POJ-1287 Networking ( 最小生成树 )
    HDU-1272 小希的迷宫 ( 并查集 )
    Java基本数据类型、关键字
    观察者模式
    Android系统启动过程分析
    Activity启动过程源码分析(Android 8.0)
    Okhttp解析—Okhttp概览
    Okhttp解析—Interceptor详解
    Okhttp源码分析--基本使用流程分析
  • 原文地址:https://www.cnblogs.com/daihengchen/p/6506386.html
Copyright © 2020-2023  润新知