• (原)torch的apply函数


    转载请注明出处:

    http://www.cnblogs.com/darkknightzh/p/6221633.html

    torch中的apply函数可以遍历model的各个模块(包括model自身),并对每个模块调用一次回调函数。实际上其使用的是深度优先(先序)遍历。

    其具体代码如下所示(代码见torch/install/share/lua/5.1/nn/Module.lua):

    -- Run a callback (called with the module as an argument) in preorder over this
    -- module and its children.
    --
    -- The callback is invoked on this module first; afterwards apply is invoked
    -- recursively on every entry of self.modules, in ipairs order, so a child's
    -- whole subtree is visited before its next sibling (depth-first).
    -- A module without a `modules` field has nothing to recurse into, which is
    -- where the recursion ends for that branch.
    --
    -- callback: function receiving the visited module as its only argument;
    --           its return value is ignored.
    function Module:apply(callback)
        -- Visit the current module before any of its children (preorder).
        callback(self)
    
        -- Only modules holding a `modules` list have children to descend into.
        if self.modules then
            for _, module in ipairs(self.modules) do
                module:apply(callback)
            end
        end
    end

    可见,apply先对当前模块调用callback,然后对self.modules中的每个子模块依次递归调用apply;当某个模块没有self.modules(即为叶子模块)时,该分支的递归结束。

    如下所示的测试代码:

    require "dpnn"
    
    -- Build the demo network that is walked with apply() further down:
    -- a convolutional stem, two Inception blocks with deliberately different
    -- settings (so they can be told apart in the printed traversal), and a
    -- pooling / linear / normalisation head.
    -- Returns the assembled nn.Sequential container.
    --
    -- NOTE(review): the stem ends with 64 output planes while the first
    -- Inception block declares inputSize = 192. The model is only traversed
    -- here, never run forward -- confirm the sizes before feeding data to it.
    function createModel()
       local net = nn.Sequential()

       -- Small helper so every layer is appended in exactly the order written.
       local function append(layer)
          net:add(layer)
       end

       -- Stem: convolution, batch normalisation, ReLU, max pooling.
       append(nn.SpatialConvolutionMM(3, 64, 7, 7, 2, 2, 3, 3))
       append(nn.SpatialBatchNormalization(64))
       append(nn.ReLU())
       append(nn.SpatialMaxPooling(3, 3, 2, 2, 1, 1))

       -- First Inception block: max-pooling branch, batch normalisation on.
       local firstInceptionOpts = {
          inputSize = 192,
          kernelSize = {3, 5},
          kernelStride = {1, 1},
          outputSize = {128, 32},
          reduceSize = {96, 16, 32, 64},
          pool = nn.SpatialMaxPooling(3, 3, 1, 1, 1, 1),
          batchNorm = true
       }
       append(nn.Inception(firstInceptionOpts))

       -- Second Inception block: LP-pooling branch, batch normalisation off.
       local secondInceptionOpts = {
          inputSize = 256,
          kernelSize = {3, 5},
          kernelStride = {1, 1},
          outputSize = {128, 64},
          reduceSize = {96, 32, 64, 64},
          pool = nn.SpatialLPPooling(256, 2, 3, 3, 1, 1),
          batchNorm = false
       }
       append(nn.Inception(secondInceptionOpts))

       -- Head: average pooling, flatten to 320, project to 128, normalise.
       append(nn.SpatialAveragePooling(7, 7))
       append(nn.View(320))
       append(nn.Linear(320, 128))
       append(nn.Normalize(2))

       return net
    end
    
    
    -- Make torch.FloatTensor the default tensor type for what is created below.
    torch.setdefaulttensortype('torch.FloatTensor')

    local model = createModel()

    --print(model)
    -- Running count of visited modules. Kept local so the demo does not leak a
    -- global variable; the callback below captures it as an upvalue.
    local visitCount = 0
    -- Print every module together with the order in which apply() reaches it.
    -- The parameter is not called `module`, to avoid shadowing Lua 5.1's
    -- built-in `module` function.
    model:apply(function(visited)
        visitCount = visitCount + 1
        print(visitCount, visited)
    end)

    其输出结果为:

    1	nn.Sequential {
      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> (7) -> (8) -> (9) -> (10) -> output]
      (1): nn.SpatialConvolutionMM(3 -> 64, 7x7, 2,2, 3,3)
      (2): nn.SpatialBatchNormalization
      (3): nn.ReLU
      (4): nn.SpatialMaxPooling(3x3, 2,2, 1,1)
      (5): nn.Inception @ nn.DepthConcat {
        input
          |`-> (1): nn.Sequential {
          |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
          |      (1): nn.SpatialConvolution(192 -> 96, 1x1)
          |      (2): nn.SpatialBatchNormalization
          |      (3): nn.ReLU
          |      (4): nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
          |      (5): nn.SpatialBatchNormalization
          |      (6): nn.ReLU
          |    }
          |`-> (2): nn.Sequential {
          |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
          |      (1): nn.SpatialConvolution(192 -> 16, 1x1)
          |      (2): nn.SpatialBatchNormalization
          |      (3): nn.ReLU
          |      (4): nn.SpatialConvolution(16 -> 32, 5x5, 1,1, 2,2)
          |      (5): nn.SpatialBatchNormalization
          |      (6): nn.ReLU
          |    }
          |`-> (3): nn.Sequential {
          |      [input -> (1) -> (2) -> (3) -> (4) -> output]
          |      (1): nn.SpatialMaxPooling(3x3, 1,1, 1,1)
          |      (2): nn.SpatialConvolution(192 -> 32, 1x1)
          |      (3): nn.SpatialBatchNormalization
          |      (4): nn.ReLU
          |    }
          |`-> (4): nn.Sequential {
                 [input -> (1) -> (2) -> (3) -> output]
                 (1): nn.SpatialConvolution(192 -> 64, 1x1)
                 (2): nn.SpatialBatchNormalization
                 (3): nn.ReLU
               }
           ... -> output
      }
      (6): nn.Inception @ nn.DepthConcat {
        input
          |`-> (1): nn.Sequential {
          |      [input -> (1) -> (2) -> (3) -> (4) -> output]
          |      (1): nn.SpatialConvolution(256 -> 96, 1x1)
          |      (2): nn.ReLU
          |      (3): nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
          |      (4): nn.ReLU
          |    }
          |`-> (2): nn.Sequential {
          |      [input -> (1) -> (2) -> (3) -> (4) -> output]
          |      (1): nn.SpatialConvolution(256 -> 32, 1x1)
          |      (2): nn.ReLU
          |      (3): nn.SpatialConvolution(32 -> 64, 5x5, 1,1, 2,2)
          |      (4): nn.ReLU
          |    }
          |`-> (3): nn.Sequential {
          |      [input -> (1) -> (2) -> (3) -> output]
          |      (1): nn.Sequential {
          |        [input -> (1) -> (2) -> (3) -> (4) -> output]
          |        (1): nn.Square
          |        (2): nn.SpatialAveragePooling(3x3, 1,1)
          |        (3): nn.MulConstant
          |        (4): nn.Sqrt
          |      }
          |      (2): nn.SpatialConvolution(256 -> 64, 1x1)
          |      (3): nn.ReLU
          |    }
          |`-> (4): nn.Sequential {
                 [input -> (1) -> (2) -> output]
                 (1): nn.SpatialConvolution(256 -> 64, 1x1)
                 (2): nn.ReLU
               }
           ... -> output
      }
      (7): nn.SpatialAveragePooling(7x7, 1,1)
      (8): nn.View(320)
      (9): nn.Linear(320 -> 128)
      (10): nn.Normalize(2)
    }
    2	nn.SpatialConvolutionMM(3 -> 64, 7x7, 2,2, 3,3)
    3	nn.SpatialBatchNormalization
    4	nn.ReLU
    5	nn.SpatialMaxPooling(3x3, 2,2, 1,1)
    6	nn.Inception @ nn.DepthConcat {
      input
        |`-> (1): nn.Sequential {
        |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
        |      (1): nn.SpatialConvolution(192 -> 96, 1x1)
        |      (2): nn.SpatialBatchNormalization
        |      (3): nn.ReLU
        |      (4): nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
        |      (5): nn.SpatialBatchNormalization
        |      (6): nn.ReLU
        |    }
        |`-> (2): nn.Sequential {
        |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
        |      (1): nn.SpatialConvolution(192 -> 16, 1x1)
        |      (2): nn.SpatialBatchNormalization
        |      (3): nn.ReLU
        |      (4): nn.SpatialConvolution(16 -> 32, 5x5, 1,1, 2,2)
        |      (5): nn.SpatialBatchNormalization
        |      (6): nn.ReLU
        |    }
        |`-> (3): nn.Sequential {
        |      [input -> (1) -> (2) -> (3) -> (4) -> output]
        |      (1): nn.SpatialMaxPooling(3x3, 1,1, 1,1)
        |      (2): nn.SpatialConvolution(192 -> 32, 1x1)
        |      (3): nn.SpatialBatchNormalization
        |      (4): nn.ReLU
        |    }
        |`-> (4): nn.Sequential {
               [input -> (1) -> (2) -> (3) -> output]
               (1): nn.SpatialConvolution(192 -> 64, 1x1)
               (2): nn.SpatialBatchNormalization
               (3): nn.ReLU
             }
         ... -> output
    }
    7	nn.DepthConcat {
      input
        |`-> (1): nn.Sequential {
        |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
        |      (1): nn.SpatialConvolution(192 -> 96, 1x1)
        |      (2): nn.SpatialBatchNormalization
        |      (3): nn.ReLU
        |      (4): nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
        |      (5): nn.SpatialBatchNormalization
        |      (6): nn.ReLU
        |    }
        |`-> (2): nn.Sequential {
        |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
        |      (1): nn.SpatialConvolution(192 -> 16, 1x1)
        |      (2): nn.SpatialBatchNormalization
        |      (3): nn.ReLU
        |      (4): nn.SpatialConvolution(16 -> 32, 5x5, 1,1, 2,2)
        |      (5): nn.SpatialBatchNormalization
        |      (6): nn.ReLU
        |    }
        |`-> (3): nn.Sequential {
        |      [input -> (1) -> (2) -> (3) -> (4) -> output]
        |      (1): nn.SpatialMaxPooling(3x3, 1,1, 1,1)
        |      (2): nn.SpatialConvolution(192 -> 32, 1x1)
        |      (3): nn.SpatialBatchNormalization
        |      (4): nn.ReLU
        |    }
        |`-> (4): nn.Sequential {
               [input -> (1) -> (2) -> (3) -> output]
               (1): nn.SpatialConvolution(192 -> 64, 1x1)
               (2): nn.SpatialBatchNormalization
               (3): nn.ReLU
             }
         ... -> output
    }
    8	nn.Sequential {
      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
      (1): nn.SpatialConvolution(192 -> 96, 1x1)
      (2): nn.SpatialBatchNormalization
      (3): nn.ReLU
      (4): nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
      (5): nn.SpatialBatchNormalization
      (6): nn.ReLU
    }
    9	nn.SpatialConvolution(192 -> 96, 1x1)
    10	nn.SpatialBatchNormalization
    11	nn.ReLU
    12	nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
    13	nn.SpatialBatchNormalization
    14	nn.ReLU
    15	nn.Sequential {
      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
      (1): nn.SpatialConvolution(192 -> 16, 1x1)
      (2): nn.SpatialBatchNormalization
      (3): nn.ReLU
      (4): nn.SpatialConvolution(16 -> 32, 5x5, 1,1, 2,2)
      (5): nn.SpatialBatchNormalization
      (6): nn.ReLU
    }
    16	nn.SpatialConvolution(192 -> 16, 1x1)
    17	nn.SpatialBatchNormalization
    18	nn.ReLU
    19	nn.SpatialConvolution(16 -> 32, 5x5, 1,1, 2,2)
    20	nn.SpatialBatchNormalization
    21	nn.ReLU
    22	nn.Sequential {
      [input -> (1) -> (2) -> (3) -> (4) -> output]
      (1): nn.SpatialMaxPooling(3x3, 1,1, 1,1)
      (2): nn.SpatialConvolution(192 -> 32, 1x1)
      (3): nn.SpatialBatchNormalization
      (4): nn.ReLU
    }
    23	nn.SpatialMaxPooling(3x3, 1,1, 1,1)
    24	nn.SpatialConvolution(192 -> 32, 1x1)
    25	nn.SpatialBatchNormalization
    26	nn.ReLU
    27	nn.Sequential {
      [input -> (1) -> (2) -> (3) -> output]
      (1): nn.SpatialConvolution(192 -> 64, 1x1)
      (2): nn.SpatialBatchNormalization
      (3): nn.ReLU
    }
    28	nn.SpatialConvolution(192 -> 64, 1x1)
    29	nn.SpatialBatchNormalization
    30	nn.ReLU
    31	nn.Inception @ nn.DepthConcat {
      input
        |`-> (1): nn.Sequential {
        |      [input -> (1) -> (2) -> (3) -> (4) -> output]
        |      (1): nn.SpatialConvolution(256 -> 96, 1x1)
        |      (2): nn.ReLU
        |      (3): nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
        |      (4): nn.ReLU
        |    }
        |`-> (2): nn.Sequential {
        |      [input -> (1) -> (2) -> (3) -> (4) -> output]
        |      (1): nn.SpatialConvolution(256 -> 32, 1x1)
        |      (2): nn.ReLU
        |      (3): nn.SpatialConvolution(32 -> 64, 5x5, 1,1, 2,2)
        |      (4): nn.ReLU
        |    }
        |`-> (3): nn.Sequential {
        |      [input -> (1) -> (2) -> (3) -> output]
        |      (1): nn.Sequential {
        |        [input -> (1) -> (2) -> (3) -> (4) -> output]
        |        (1): nn.Square
        |        (2): nn.SpatialAveragePooling(3x3, 1,1)
        |        (3): nn.MulConstant
        |        (4): nn.Sqrt
        |      }
        |      (2): nn.SpatialConvolution(256 -> 64, 1x1)
        |      (3): nn.ReLU
        |    }
        |`-> (4): nn.Sequential {
               [input -> (1) -> (2) -> output]
               (1): nn.SpatialConvolution(256 -> 64, 1x1)
               (2): nn.ReLU
             }
         ... -> output
    }
    32	nn.DepthConcat {
      input
        |`-> (1): nn.Sequential {
        |      [input -> (1) -> (2) -> (3) -> (4) -> output]
        |      (1): nn.SpatialConvolution(256 -> 96, 1x1)
        |      (2): nn.ReLU
        |      (3): nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
        |      (4): nn.ReLU
        |    }
        |`-> (2): nn.Sequential {
        |      [input -> (1) -> (2) -> (3) -> (4) -> output]
        |      (1): nn.SpatialConvolution(256 -> 32, 1x1)
        |      (2): nn.ReLU
        |      (3): nn.SpatialConvolution(32 -> 64, 5x5, 1,1, 2,2)
        |      (4): nn.ReLU
        |    }
        |`-> (3): nn.Sequential {
        |      [input -> (1) -> (2) -> (3) -> output]
        |      (1): nn.Sequential {
        |        [input -> (1) -> (2) -> (3) -> (4) -> output]
        |        (1): nn.Square
        |        (2): nn.SpatialAveragePooling(3x3, 1,1)
        |        (3): nn.MulConstant
        |        (4): nn.Sqrt
        |      }
        |      (2): nn.SpatialConvolution(256 -> 64, 1x1)
        |      (3): nn.ReLU
        |    }
        |`-> (4): nn.Sequential {
               [input -> (1) -> (2) -> output]
               (1): nn.SpatialConvolution(256 -> 64, 1x1)
               (2): nn.ReLU
             }
         ... -> output
    }
    33	nn.Sequential {
      [input -> (1) -> (2) -> (3) -> (4) -> output]
      (1): nn.SpatialConvolution(256 -> 96, 1x1)
      (2): nn.ReLU
      (3): nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
      (4): nn.ReLU
    }
    34	nn.SpatialConvolution(256 -> 96, 1x1)
    35	nn.ReLU
    36	nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
    37	nn.ReLU
    38	nn.Sequential {
      [input -> (1) -> (2) -> (3) -> (4) -> output]
      (1): nn.SpatialConvolution(256 -> 32, 1x1)
      (2): nn.ReLU
      (3): nn.SpatialConvolution(32 -> 64, 5x5, 1,1, 2,2)
      (4): nn.ReLU
    }
    39	nn.SpatialConvolution(256 -> 32, 1x1)
    40	nn.ReLU
    41	nn.SpatialConvolution(32 -> 64, 5x5, 1,1, 2,2)
    42	nn.ReLU
    43	nn.Sequential {
      [input -> (1) -> (2) -> (3) -> output]
      (1): nn.Sequential {
        [input -> (1) -> (2) -> (3) -> (4) -> output]
        (1): nn.Square
        (2): nn.SpatialAveragePooling(3x3, 1,1)
        (3): nn.MulConstant
        (4): nn.Sqrt
      }
      (2): nn.SpatialConvolution(256 -> 64, 1x1)
      (3): nn.ReLU
    }
    44	nn.Sequential {
      [input -> (1) -> (2) -> (3) -> (4) -> output]
      (1): nn.Square
      (2): nn.SpatialAveragePooling(3x3, 1,1)
      (3): nn.MulConstant
      (4): nn.Sqrt
    }
    45	nn.Square
    46	nn.SpatialAveragePooling(3x3, 1,1)
    47	nn.MulConstant
    48	nn.Sqrt
    49	nn.SpatialConvolution(256 -> 64, 1x1)
    50	nn.ReLU
    51	nn.Sequential {
      [input -> (1) -> (2) -> output]
      (1): nn.SpatialConvolution(256 -> 64, 1x1)
      (2): nn.ReLU
    }
    52	nn.SpatialConvolution(256 -> 64, 1x1)
    53	nn.ReLU
    54	nn.SpatialAveragePooling(7x7, 1,1)
    55	nn.View(320)
    56	nn.Linear(320 -> 128)
    57	nn.Normalize(2)
    View Code

    由上述结果可以看出,使用apply后,第1次输出整个模型,此处为最顶层的。

    第2-5次输出:

    2       nn.SpatialConvolutionMM(3 -> 64, 7x7, 2,2, 3,3)

    3       nn.SpatialBatchNormalization

    4       nn.ReLU

    5       nn.SpatialMaxPooling(3x3, 2,2, 1,1)

    为Inception之前的几个层。

    第6次为nn.Inception @ nn.DepthConcat,第7次为nn.DepthConcat。此处是第一个Inception层(从输出可以看出,Inception模块内部包含一个nn.DepthConcat子模块,因此apply先访问Inception本身,再访问其内部的DepthConcat)。

    第8次为Inception的第一个nn.Sequential,第9-14次为该层的具体层。此时已经到了第一个最底层。

    第15次为Inception的第二个nn.Sequential,第16-21次为该层的具体层。此时已经到了第二个最底层。

    第22次为Inception的第三个nn.Sequential,第23-26次为该层的具体层。此时已经到了第三个最底层。

    第27次为Inception的第四个nn.Sequential,第28-30次为该层的具体层。此时已经到了第四个最底层。

    至此,第一个Inception层通过深度优先的方式遍历完毕。

    第31次为nn.Inception @ nn.DepthConcat,第32次为nn.DepthConcat。此处是第二个Inception层(注意,为了区分第一个Inception和第二个Inception层,这两个层具体结构不完全一样)。

    第33次为Inception的第一个nn.Sequential,第34-37次为该层的具体层。此时已经到了第一个最底层。

    第38次为Inception的第二个nn.Sequential,第39-42次为该层的具体层。此时已经到了第二个最底层。

    第43次为Inception的第三个nn.Sequential。

    第44次为第三个nn.Sequential的第一个小module(也是一个nn.Sequential)。第45-48依次遍历此nn.Sequential。到了最底层后遍历完毕。

    第49-50为第三个nn.Sequential的最后两层。

    第51次为Inception的第四个nn.Sequential,第52-53次为该层的具体层。此时已经到了第四个最底层。

    至此,第二个Inception层通过深度优先的方式遍历完毕。

    第54-57为最后的四个层。

    由上面可以看出,apply采用的是深度优先的方式进行遍历。

  • 相关阅读:
    BZOJ3282 Tree
    [NOI2004] 郁闷的出纳员
    [HNOI2004]宠物收养所
    [HNOI2002] 营业额统计
    图论 简单学习笔记
    POJ3321 Apple tree
    [国家集训队] 聪聪可可
    POJ2976 Dropping tests
    SCOI2005 最大子矩阵
    codeforces|CF13C Sequence
  • 原文地址:https://www.cnblogs.com/darkknightzh/p/6221633.html
Copyright © 2020-2023  润新知