(Original) The torch training process


    Please credit the source when reposting:

    http://www.cnblogs.com/darkknightzh/p/6221622.html

    References:

    http://ju.outofmemory.cn/entry/284587

    https://github.com/torch/nn/blob/master/doc/criterion.md

    1. Using updateParameters

    Suppose we already have model = setupmodel (a model we built ourselves), along with training data input, the ground-truth output outReal, and a loss function criterion (see the second reference above). Training in torch then proceeds as follows:

    1 -- given model, criterion, input, outReal
    2 model:training()
    3 model:zeroGradParameters()
    4 outPredict = model:forward(input)
    5 err = criterion:forward(outPredict, outReal)
    6 grad_criterion = criterion:backward(outPredict, outReal)
    7 model:backward(input, grad_criterion)
    8 model:updateParameters(learningRate)

    Line 1 lists the quantities assumed to be given.

    Line 2 puts the model in training mode.

    Line 3 zeroes the gradients stored in every module of the model (so that previous iterations do not contaminate the current one).

    Line 4 passes the input through the model to obtain the predicted output outPredict.

    Line 5 uses the loss function to compute the error err between the prediction outPredict and the ground truth outReal under the current parameters.

    Line 6 computes the gradient of the loss, grad_criterion, from outPredict and outReal.

    Line 7 backpropagates grad_criterion through the model, computing the gradient of every module.

    Line 8 updates the parameters of every module.

    Lines 3 through 8 must be executed on every iteration.
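
    For concreteness, here is a minimal self-contained sketch of that loop (my addition; the toy regression model, the data, and the iteration count are made up for illustration):

    require 'nn'

    -- a toy model and criterion so the eight lines above can actually run
    local model = nn.Sequential()
    model:add(nn.Linear(10, 5))
    model:add(nn.Tanh())
    model:add(nn.Linear(5, 1))
    local criterion = nn.MSECriterion()

    local input = torch.randn(10)      -- one training sample
    local outReal = torch.randn(1)     -- its target output
    local learningRate = 0.01

    model:training()                   -- line 2: once is enough here
    for iter = 1, 100 do               -- lines 3-8 repeat every iteration
        model:zeroGradParameters()
        local outPredict = model:forward(input)
        local err = criterion:forward(outPredict, outReal)
        local grad_criterion = criterion:backward(outPredict, outReal)
        model:backward(input, grad_criterion)
        model:updateParameters(learningRate)
    end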

    =========================================================

    2. Using optim

    Update 2017-03-01:

    http://x-wei.github.io/learn-torch-6-optim.html

    describes a more convenient approach (how much more convenient is debatable): using torch's optim package to update the parameters. Calling model:updateParameters directly gives you only the simplest gradient descent, whereas optim wraps many algorithms (optim.sgd, optim.adagrad, optim.rmsprop, optim.adam, and so on).

    params_new, fs, ... = optim._method_(feval, params[, config][, state])

    where params: the current parameter vector (a 1D tensor), updated in place during optimization

    feval: a user-defined closure that behaves like f, df/dx = feval(x)

    config: a table of algorithm parameters (such as the learning rate)

    state: a table of state variables

    params_new: the new parameter vector (a 1D tensor) that minimizes f

    fs: a table of the f values evaluated during optimization; fs[#fs] is the final optimized function value
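
    As a standalone illustration of the feval contract (my addition, not from the referenced tutorial): minimizing f(x) = x^2 with optim.sgd. The starting value and learning rate are arbitrary.

    require 'optim'

    local xx = torch.Tensor{5.0}                  -- the 1D parameter tensor
    local function feval(x_new)
        local f = x_new[1] * x_new[1]             -- f(x)  = x^2
        local df_dx = torch.Tensor{2 * x_new[1]}  -- df/dx = 2x
        return f, df_dx
    end

    local state = {learningRate = 0.1}
    for i = 1, 50 do
        optim.sgd(feval, xx, state)               -- updates xx in place
    end
    print(xx[1])                                  -- close to 0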

    Note: since optim requires its input to be a 1D tensor, the model parameters must be flattened, which is done with the following call:

    params, gradParams = model:getParameters()

    Both params and gradParams are 1D tensors.
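
    Two properties of getParameters are worth noting (my addition, based on the nn module documentation): the returned tensors share storage with every module's weight and gradWeight tensors, and since the flattening relocates those storages, the function should be called only once, before training:

    params, gradParams = model:getParameters()  -- call exactly once, before training
    params:add(0.1)                -- writes through to every module's weights
    model:zeroGradParameters()     -- after flattening, equivalent to gradParams:zero()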

    With optim, the program from section 1 can be rewritten as:

    -- given model, criterion, input, outReal, optimState
    local params, gradParams = model:getParameters()

    local function feval()
        -- forward/backward have already been run in the loop below, so feval
        -- only returns the cached loss and the flattened gradients
        return criterion.output, gradParams
    end

    for ...
        model:training()
        model:zeroGradParameters()
        outPredict = model:forward(input)
        err = criterion:forward(outPredict, outReal)
        grad_criterion = criterion:backward(outPredict, outReal)
        model:backward(input, grad_criterion)

        optim.sgd(feval, params, optimState)
    end
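
    Switching optimization algorithms then only changes the optim call and its config table. For example, with Adam instead of plain SGD (a sketch; the learning rate is illustrative):

    local adamState = {learningRate = 1e-3}
    -- inside the same loop, in place of optim.sgd:
    optim.adam(feval, params, adamState)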

    End of update 2017-03-01.

    =========================================================

    3. A pitfall when using model:backward

    Update 2017-04-05:

    Note that each model:backward call must be paired with a model:forward call on the same input.

    The entry for [gradInput] backward(input, gradOutput) in https://github.com/torch/nn/blob/master/doc/module.md says:

    In general this method makes the assumption forward(input) has been called before, with the same input. This is necessary for optimization reasons. If you do not respect this rule, backward() will compute incorrect gradients.

    The likely reason is that backward reuses intermediate variables saved during forward, so forward must have been run on the same input immediately beforehand; otherwise the saved intermediates do not match and the computed gradients are wrong.

    In my earlier program, only the intermediate variables of the last forward pass were kept after the initial forwards, so the backward results were always wrong (see the commented-out line in method5 below).

    The only fix I found is rather clumsy: do the forward passes first, then, immediately before each backward, run forward once more on the same mini-batch. This guarantees correct results at the cost of one extra forward computation. See method5 in the code.
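
    The pitfall in miniature (my sketch; input1, input2, and grad1 stand for any tensors of suitable shapes):

    local out1 = model:forward(input1)
    local out2 = model:forward(input2)   -- overwrites the intermediates saved for input1
    model:backward(input1, grad1)        -- WRONG: pairs input1 with input2's intermediates

    model:forward(input1)                -- re-run forward on the same input ...
    model:backward(input1, grad1)        -- ... now the gradients are correct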

    Notes on the program below:

    method1 is the ordinary batch approach. It works, but demands a lot of GPU memory.

    method2 therefore mimics caffe's iter_size (not exactly the same): the batch is split into mini-batches and their gradients are accumulated.

    method3 targets the case where more samples are needed and the criterion should see as many samples as possible at once, which neither of the first two approaches supports: it first collects all mini-batch outputs, then all losses, then all criterion gradients, and finally runs all the backward passes. In practice method3 is broken: the loss converges very slowly.

    method4 refines and tests method3. With the two marked lines commented out (forward and backward in the same loop) it converges normally; with them uncommented, convergence degrades just as in method3.

    method5 is the final fix and converges normally. To verify that forward and backward must correspond, uncomment the commented line in method5 and comment out the line above it; convergence then becomes very slow again (as in methods 3 and 4).

    The convergence of each method over the first 10 epochs is listed after the code.

    The program is as follows:

    require 'torch'
    require 'nn'
    require 'optim'
    require 'cunn'
    require 'cutorch'
    local mnist = require 'mnist'

    local fullset = mnist.traindataset()
    local testset = mnist.testdataset()

    local trainset = {
        size = 50000,
        data = fullset.data[{{1,50000}}]:double(),
        label = fullset.label[{{1,50000}}]
    }
    trainset.data = trainset.data - trainset.data:mean()
    trainset.data = trainset.data:cuda()
    trainset.label = trainset.label:cuda()

    local validationset = {
        size = 10000,
        data = fullset.data[{{50001,60000}}]:double(),
        label = fullset.label[{{50001,60000}}]
    }
    validationset.data = validationset.data - validationset.data:mean()
    validationset.data = validationset.data:cuda()
    validationset.label = validationset.label:cuda()

    local model = nn.Sequential()
    model:add(nn.Reshape(1, 28, 28))
    model:add(nn.MulConstant(1/256.0*3.2))
    model:add(nn.SpatialConvolutionMM(1, 20, 5, 5, 1, 1, 0, 0))
    model:add(nn.SpatialMaxPooling(2, 2, 2, 2, 0, 0))
    model:add(nn.SpatialConvolutionMM(20, 50, 5, 5, 1, 1, 0, 0))
    model:add(nn.SpatialMaxPooling(2, 2, 2, 2, 0, 0))
    model:add(nn.Reshape(4*4*50))
    model:add(nn.Linear(4*4*50, 500))
    model:add(nn.ReLU())
    model:add(nn.Linear(500, 10))
    model:add(nn.LogSoftMax())

    model = require('weight-init')(model, 'xavier')
    model = model:cuda()

    x, dl_dx = model:getParameters()

    local criterion = nn.ClassNLLCriterion():cuda()

    local sgd_params = {
       learningRate = 1e-2,
       learningRateDecay = 1e-4,
       weightDecay = 1e-3,
       momentum = 1e-4
    }

    local training = function(batchSize)
        local current_loss = 0
        local count = 0
        local shuffle = torch.randperm(trainset.size)
        batchSize = batchSize or 200
        for t = 0, trainset.size-1, batchSize do
            -- set up inputs and targets for this batch
            local size = math.min(t + batchSize, trainset.size) - t
            local inputs = torch.Tensor(size, 28, 28):cuda()
            local targets = torch.Tensor(size):cuda()
            for i = 1, size do
                inputs[i] = trainset.data[shuffle[i+t]]
                targets[i] = trainset.label[shuffle[i+t]] + 1  -- labels 0-9 -> classes 1-10
            end

            local feval = function(x_new)
                local miniBatchSize = 20
                if x ~= x_new then x:copy(x_new) end   -- sync the flat parameter tensor
                dl_dx:zero()

                --[[ ------------------ method 1: plain batch
                local outputs = model:forward(inputs)
                local loss = criterion:forward(outputs, targets)
                local gradInput = criterion:backward(outputs, targets)
                model:backward(inputs, gradInput)
                --]]

                --[[ ------------------ method 2: iter_size-style accumulation over mini-batches
                local loss = 0
                for idx = 1, batchSize, miniBatchSize do
                    local outputs = model:forward(inputs[{{idx, idx + miniBatchSize - 1}}])
                    loss = loss + criterion:forward(outputs, targets[{{idx, idx + miniBatchSize - 1}}])
                    local gradInput = criterion:backward(outputs, targets[{{idx, idx + miniBatchSize - 1}}])
                    model:backward(inputs[{{idx, idx + miniBatchSize - 1}}], gradInput)
                end
                dl_dx:mul(1.0 * miniBatchSize / batchSize)
                loss = loss * miniBatchSize / batchSize
                --]]

                --[[ ------------------ method 3: mini-batch in batch (converges slowly:
                -- each backward no longer matches the last forward)
                local outputs = torch.Tensor(batchSize, 10):zero():cuda()
                for idx = 1, batchSize, miniBatchSize do
                    outputs[{{idx, idx + miniBatchSize - 1}}]:copy(model:forward(inputs[{{idx, idx + miniBatchSize - 1}}]))
                end
                local loss = 0
                for idx = 1, batchSize, miniBatchSize do
                    loss = loss + criterion:forward(outputs[{{idx, idx + miniBatchSize - 1}}],
                        targets[{{idx, idx + miniBatchSize - 1}}])
                end
                local gradInput = torch.Tensor(batchSize, 10):zero():cuda()
                for idx = 1, batchSize, miniBatchSize do
                    gradInput[{{idx, idx + miniBatchSize - 1}}]:copy(criterion:backward(
                        outputs[{{idx, idx + miniBatchSize - 1}}], targets[{{idx, idx + miniBatchSize - 1}}]))
                end
                for idx = 1, batchSize, miniBatchSize do
                    model:backward(inputs[{{idx, idx + miniBatchSize - 1}}], gradInput[{{idx, idx + miniBatchSize - 1}}])
                end
                dl_dx:mul(1.0 * miniBatchSize / batchSize)
                loss = loss * miniBatchSize / batchSize
                --]]

                --[[ ------------------ method 4: like method 3; converges only while forward
                -- and backward stay in the same loop (the two marked lines commented out)
                local outputs = torch.Tensor(batchSize, 10):zero():cuda()
                local loss = 0
                local gradInput = torch.Tensor(batchSize, 10):zero():cuda()
                for idx = 1, batchSize, miniBatchSize do
                    outputs[{{idx, idx + miniBatchSize - 1}}]:copy(model:forward(inputs[{{idx, idx + miniBatchSize - 1}}]))
                    loss = loss + criterion:forward(outputs[{{idx, idx + miniBatchSize - 1}}],
                        targets[{{idx, idx + miniBatchSize - 1}}])
                    gradInput[{{idx, idx + miniBatchSize - 1}}]:copy(criterion:backward(
                        outputs[{{idx, idx + miniBatchSize - 1}}], targets[{{idx, idx + miniBatchSize - 1}}]))
                --  end                                        -- uncommenting these two lines
                --  for idx = 1, batchSize, miniBatchSize do   -- breaks convergence (see notes above)
                    model:backward(inputs[{{idx, idx + miniBatchSize - 1}}], gradInput[{{idx, idx + miniBatchSize - 1}}])
                end

                dl_dx:mul(1.0 * miniBatchSize / batchSize)
                loss = loss * miniBatchSize / batchSize
                --]]


                ---[[ ------------------ method 5: mini-batch in batch, forward re-run
                -- immediately before each backward so the saved intermediates match
                local loss = 0
                local gradInput = torch.Tensor(batchSize, 10):zero():cuda()

                for idx = 1, batchSize, miniBatchSize do
                    local outputs = model:forward(inputs[{{idx, idx + miniBatchSize - 1}}])
                    loss = loss + criterion:forward(outputs, targets[{{idx, idx + miniBatchSize - 1}}])
                    gradInput[{{idx, idx + miniBatchSize - 1}}]:copy(criterion:backward(outputs, targets[{{idx, idx + miniBatchSize - 1}}]))
                end

                for idx = 1, batchSize, miniBatchSize do
                    model:forward(inputs[{{idx, idx + miniBatchSize - 1}}])   -- the extra forward, see text
                    --model:forward(inputs[{{batchSize - miniBatchSize + 1, batchSize}}])
                    model:backward(inputs[{{idx, idx + miniBatchSize - 1}}], gradInput[{{idx, idx + miniBatchSize - 1}}])
                end

                dl_dx:mul(1.0 * miniBatchSize / batchSize)
                loss = loss * miniBatchSize / batchSize
                --]]

                return loss, dl_dx
            end

            _, fs = optim.sgd(feval, x, sgd_params)

            count = count + 1
            current_loss = current_loss + fs[1]
        end

        return current_loss / count   -- normalize loss
    end

    local eval = function(dataset, batchSize)
        local count = 0
        batchSize = batchSize or 200

        for i = 1, dataset.size, batchSize do
            local size = math.min(i + batchSize - 1, dataset.size) - i + 1  -- "+ 1" fixes an off-by-one in the original, which silently skipped the last sample of every batch
            local inputs = dataset.data[{{i,i+size-1}}]:cuda()
            local targets = dataset.label[{{i,i+size-1}}]
            local outputs = model:forward(inputs)
            local _, indices = torch.max(outputs, 2)
            indices:add(-1)   -- classes 1-10 back to labels 0-9
            indices = indices:cuda()
            local guessed_right = indices:eq(targets):sum()
            count = count + guessed_right
        end

        return count / dataset.size
    end


    local max_iters = 50
    local last_accuracy = 0
    local decreasing = 0
    local threshold = 1 -- how many decreasing epochs we allow
    for i = 1,max_iters do
       -- timer = torch.Timer()

        model:training()
        local loss = training()

        model:evaluate()
        local accuracy = eval(validationset)
        print(string.format('Epoch: %d Current loss: %4f; validation set accu: %4f', i, loss, accuracy))
        if accuracy < last_accuracy then
            if decreasing > threshold then break end
            decreasing = decreasing + 1
        else
            decreasing = 0
        end
        last_accuracy = accuracy

        --print('    Time elapsed: ' .. i .. 'iter: ' .. timer:time().real .. ' seconds')
    end

    testset.data = testset.data:double()
    testset.data = testset.data - testset.data:mean()  -- center like the training data (missing in the original)
    testset.label = testset.label:cuda()               -- eval's eq() comparison needs both tensors on the GPU
    print(eval(testset))
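    A note on running it (my addition): the program needs the mnist dataset package (it should be available via luarocks install mnist) and the weight-init.lua file below on the Lua package path.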

    weight-init.lua

    --
    -- Different weight initialization methods
    --
    -- > model = require('weight-init')(model, 'heuristic')
    --
    require("nn")


    -- "Efficient backprop"
    -- Yann Lecun, 1998
    local function w_init_heuristic(fan_in, fan_out)
       return math.sqrt(1/(3*fan_in))
    end

    -- "Understanding the difficulty of training deep feedforward neural networks"
    -- Xavier Glorot, 2010
    local function w_init_xavier(fan_in, fan_out)
       return math.sqrt(2/(fan_in + fan_out))
    end

    -- "Understanding the difficulty of training deep feedforward neural networks"
    -- Xavier Glorot, 2010
    local function w_init_xavier_caffe(fan_in, fan_out)
       return math.sqrt(1/fan_in)
    end

    -- "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification"
    -- Kaiming He, 2015
    local function w_init_kaiming(fan_in, fan_out)
       return math.sqrt(4/(fan_in + fan_out))
    end

    local function w_init(net, arg)
       -- choose initialization method
       local method = nil
       if     arg == 'heuristic'    then method = w_init_heuristic
       elseif arg == 'xavier'       then method = w_init_xavier
       elseif arg == 'xavier_caffe' then method = w_init_xavier_caffe
       elseif arg == 'kaiming'      then method = w_init_kaiming
       else
          assert(false)
       end

       -- loop over all convolutional modules
       for i = 1, #net.modules do
          local m = net.modules[i]
          if m.__typename == 'nn.SpatialConvolution' then
             m:reset(method(m.nInputPlane*m.kH*m.kW, m.nOutputPlane*m.kH*m.kW))
          elseif m.__typename == 'nn.SpatialConvolutionMM' then
             m:reset(method(m.nInputPlane*m.kH*m.kW, m.nOutputPlane*m.kH*m.kW))
          elseif m.__typename == 'cudnn.SpatialConvolution' then
             m:reset(method(m.nInputPlane*m.kH*m.kW, m.nOutputPlane*m.kH*m.kW))
          elseif m.__typename == 'nn.LateralConvolution' then
             m:reset(method(m.nInputPlane*1*1, m.nOutputPlane*1*1))
          elseif m.__typename == 'nn.VerticalConvolution' then
             m:reset(method(1*m.kH*m.kW, 1*m.kH*m.kW))
          elseif m.__typename == 'nn.HorizontalConvolution' then
             m:reset(method(1*m.kH*m.kW, 1*m.kH*m.kW))
          elseif m.__typename == 'nn.Linear' then
             m:reset(method(m.weight:size(2), m.weight:size(1)))
          elseif m.__typename == 'nn.TemporalConvolution' then
             m:reset(method(m.weight:size(2), m.weight:size(1)))
          end

          if m.bias then
             m.bias:zero()
          end
       end
       return net
    end

    return w_init
    Method 1
    
    Epoch: 1 Current loss: 0.616950; validation set accu: 0.920900	
    Epoch: 2 Current loss: 0.228665; validation set accu: 0.942400	
    Epoch: 3 Current loss: 0.168047; validation set accu: 0.957900	
    Epoch: 4 Current loss: 0.134796; validation set accu: 0.961800	
    Epoch: 5 Current loss: 0.113071; validation set accu: 0.966200	
    Epoch: 6 Current loss: 0.098782; validation set accu: 0.968800	
    Epoch: 7 Current loss: 0.088252; validation set accu: 0.970000	
    Epoch: 8 Current loss: 0.080225; validation set accu: 0.971200	
    Epoch: 9 Current loss: 0.073702; validation set accu: 0.972200	
    Epoch: 10 Current loss: 0.068171; validation set accu: 0.972400	
    
    Method 2
    Epoch: 1 Current loss: 0.624633; validation set accu: 0.922200	
    Epoch: 2 Current loss: 0.238459; validation set accu: 0.945200	
    Epoch: 3 Current loss: 0.174089; validation set accu: 0.959000	
    Epoch: 4 Current loss: 0.140234; validation set accu: 0.963800	
    Epoch: 5 Current loss: 0.116498; validation set accu: 0.968000	
    Epoch: 6 Current loss: 0.101376; validation set accu: 0.968800	
    Epoch: 7 Current loss: 0.089484; validation set accu: 0.972600	
    Epoch: 8 Current loss: 0.080812; validation set accu: 0.973000	
    Epoch: 9 Current loss: 0.073929; validation set accu: 0.975100	
    Epoch: 10 Current loss: 0.068330; validation set accu: 0.975400	
    
    Method 3
    Epoch: 1 Current loss: 2.202240; validation set accu: 0.548500	
    Epoch: 2 Current loss: 2.049710; validation set accu: 0.669300	
    Epoch: 3 Current loss: 1.993560; validation set accu: 0.728900	
    Epoch: 4 Current loss: 1.959818; validation set accu: 0.774500	
    Epoch: 5 Current loss: 1.945992; validation set accu: 0.757600	
    Epoch: 6 Current loss: 1.930599; validation set accu: 0.809600	
    Epoch: 7 Current loss: 1.911803; validation set accu: 0.837200	
    Epoch: 8 Current loss: 1.904754; validation set accu: 0.842100	
    Epoch: 9 Current loss: 1.903705; validation set accu: 0.846400	
    Epoch: 10 Current loss: 1.903911; validation set accu: 0.848100	
    
    Method 4
    Epoch: 1 Current loss: 0.624240; validation set accu: 0.924900	
    Epoch: 2 Current loss: 0.213469; validation set accu: 0.948500	
    Epoch: 3 Current loss: 0.156797; validation set accu: 0.959800	
    Epoch: 4 Current loss: 0.126438; validation set accu: 0.963900	
    Epoch: 5 Current loss: 0.106664; validation set accu: 0.965900	
    Epoch: 6 Current loss: 0.094166; validation set accu: 0.967200	
    Epoch: 7 Current loss: 0.084848; validation set accu: 0.971200	
    Epoch: 8 Current loss: 0.077244; validation set accu: 0.971800	
    Epoch: 9 Current loss: 0.071417; validation set accu: 0.973300	
    Epoch: 10 Current loss: 0.065737; validation set accu: 0.971600	
    
    
    Method 4, with the two marked lines uncommented:
    Epoch: 1 Current loss: 2.178319; validation set accu: 0.542200	
    Epoch: 2 Current loss: 2.031493; validation set accu: 0.648700	
    Epoch: 3 Current loss: 1.982282; validation set accu: 0.703700	
    Epoch: 4 Current loss: 1.956709; validation set accu: 0.762700	
    Epoch: 5 Current loss: 1.927590; validation set accu: 0.808100	
    Epoch: 6 Current loss: 1.924535; validation set accu: 0.817200	
    Epoch: 7 Current loss: 1.911364; validation set accu: 0.820100	
    Epoch: 8 Current loss: 1.898206; validation set accu: 0.855400	
    Epoch: 9 Current loss: 1.885394; validation set accu: 0.836500	
    Epoch: 10 Current loss: 1.880787; validation set accu: 0.870200	
    
    
    Method 5
    
    Epoch: 1 Current loss: 0.619814; validation set accu: 0.924300	
    Epoch: 2 Current loss: 0.232870; validation set accu: 0.948800	
    Epoch: 3 Current loss: 0.172606; validation set accu: 0.954900	
    Epoch: 4 Current loss: 0.137763; validation set accu: 0.961800	
    Epoch: 5 Current loss: 0.116268; validation set accu: 0.967700	
    Epoch: 6 Current loss: 0.101985; validation set accu: 0.968800	
    Epoch: 7 Current loss: 0.091154; validation set accu: 0.970900	
    Epoch: 8 Current loss: 0.083219; validation set accu: 0.972700	
    Epoch: 9 Current loss: 0.074921; validation set accu: 0.972800	
    Epoch: 10 Current loss: 0.070208; validation set accu: 0.972800	
    
    
    Method 5, with the alternative forward line uncommented and the line above it commented out:
    
    Epoch: 1 Current loss: 2.161032; validation set accu: 0.497500	
    Epoch: 2 Current loss: 2.027255; validation set accu: 0.690900	
    Epoch: 3 Current loss: 1.972939; validation set accu: 0.767600	
    Epoch: 4 Current loss: 1.940982; validation set accu: 0.766000	
    Epoch: 5 Current loss: 1.933135; validation set accu: 0.812800	
    Epoch: 6 Current loss: 1.913039; validation set accu: 0.799300	
    Epoch: 7 Current loss: 1.896871; validation set accu: 0.848800	
    Epoch: 8 Current loss: 1.899655; validation set accu: 0.854400	
    Epoch: 9 Current loss: 1.889465; validation set accu: 0.845700	
    Epoch: 10 Current loss: 1.878703; validation set accu: 0.846400	

    End of update 2017-04-05.

    =========================================================
