• torch 深度学习(4)


    torch 深度学习(4)

    经过数据的预处理、模型创建、损失函数定义以及模型的训练,现在可以使用训练好的模型对测试集进行测试了。测试模块比训练模块简单的多,只需调用模型就可以了

    测试模块

    1. 加载模块

    require 'torch'
    require 'xlua'      -- 主要使用进度条用到
    require 'optim'     -- confusionMatrix和Logger会用到
    

    这里多说一句,为什么每个模块都没有调用之前的模块呢?这是因为我们我们最终是将之前的预处理、建模等模块放到一个项目里面一起end-to-end执行的,而且模块里面的共享参数都是全局变量所以不需要再加载之前的模块了。如果单独执行某一个文件,比如_'3_loss.lua' 文件,里面为了能够运行,创建了一个model=nn.Sequential()只是为了运行没有实际意义,像 '4_train.lua' 模块不加载之前的模块则不能单独运行。

    1. 测试函数

    function test()
        local time = sys.time()
        
        for t=1,testData:size() do
            xlua.progress(t, testData:size())
            
            local input=testData.data[t]:double()
            local target = testData.labels[t]
            
            pred = model:forward(input)   --使用模型预测
            
            local _,indices = torch.sort(pred,true) --降序排列
            confusion:add(indices[1],target) --注意这里的混淆矩阵是在4_train.lua中定义的,每次都清零了,所以没有影响
        end
        
        time=sys.clock()-time
        time=time/testData:size()  -- 单位所需时间
        print('==> time to test 1 sample =' .. (time*1000) .. 'ms')   -- ms单位
        print(confusion) --打印混淆矩阵
        testLogger.add{['% mean class accuracy (test set)'] = confusion.totalValid*100}
        if opt.plot then
            testLogger:style('-') --折线图
            testLogger:plot()  --结果变化趋势图
        end
        confusion:zero() --reset confusionMatrix
    end
    

    项目执行

    将所有的模块放到一起统一执行

    1. 加载模块

    require 'torch'  -- 其他模块需要的包他们自己加载
    
    1. 设置参数,注意这里设置了opt参数,那么其他模块的命令行参数设置的代码块都不会执行

    cmd = torch.CmdLine()
    cmd:text()
    cmd:text('参数设置')
    cmd:text()
    cmd:text('Options:')
    cmdLtext()
    cmd:option('-seed',1,'fixed input seed for repeatable experiments') --因为代码涉及到随机数,为了实验的可重复性,设置固定的随机种子值
    cmd:option('-size','small','how many samples do we load: small | full | extra')
    cmd:option('-model','convnet','type of model to construct: linear | mlp | convnet')
    cmd:option('-loss','nll','type of loss function to minimization: nll | mse | margin')
    
    cmd:option('-save','results','subdirectory to save/log experiments in')
    cmd:option('-plot',false,'live plot')
    cmd:option('-optimization','SGD','optimization method: SGD | ASGD | LBFGS | CG')
    cmd:option('-batchSize',10,'mini-batchSize (1= pure stochastic)')
    cmd:option('-learningRate',1e-3,'learning rate at t=0')
    cmd:option('-weightDecay',0, 'weight decay(SGD only)')
    cmd:option('-momentum',0,'momentum(SGD only)')
    cmd:option('-t0',1,'start average at t0 (ASGD only) in nb of epochs')
    cmd:option('-maxIter',2,'maximum nb of iteration for CG and LBFGS')
    cmd:text()
    opt=cmd:parse(arg or {})
    
    torch.setnumthread(4)  --设置并行的线程数,这个不能设置太大,因为线程切换也需要时间,而且他们共用模型参数
    torch.manualSeed(opt.seed)  --设置随机种子
    
    1. 依次执行模块

    dofile '1_data.lua'
    dofile '2_model.lua'
    dofile '3_loss.lua'
    dofile '4_train.lua'
    dofile '5_test.lua'
    

    dofile 是lua语言里面的函数, loadfile 编译不运行, dofile 运行文件,参见Lua中require,dofile、dofile的区别

    1. 训练并测试

    while true do 
        train()
        test()
        if epoch == 30 then 
            break;
        end
        if epoch == 27 then
            opt.plot=true
        end
    end
    

    这里我执行了30个周期,并且在输出后4个周期的实验结果,这里是指结果变化曲线图

    实验结果

    1.混淆矩阵(第27次)

    enter description here

    27th epoch.png

    1. 混淆矩阵的可视化显示 render,这个图像的对角线表示正确对应类的正确率,每一行非对角就是错分别别的类的比率
      Confusion matrix

    enter description here

    confusion.png

    1. 日志文件

    enter description here

    log.png

    1. 模型性能变化趋势

    enter description here

    plot.png

    观测这两个图可以发现训练集曲线是单调递增的,这是因为优化算法目标就是让性能不断上升,但也只是能够保证训练集,而对于测试集显然有时候精度反而下降了。这也体现了泛化的概念,训练集好不一定训练集也好,可能过拟合

    1. 这段代码都是跑的small规模数据集,10000个训练样本,大约80s完成一次epoch的训练。我是用full规模数据训练,7万多数据训练跑到50次左右训练集结果95%左右,测试集93%左右,时间训练一轮大约13分钟

  • 相关阅读:
    maven中没找到settings.xml文件怎么办,简单粗暴
    如何修改新建后的maven的jdk版本号,简单粗暴
    如何修改maven下载的jar包存放位置,简单粗暴方法
    Kafka 温故(一):Kafka背景及架构介绍
    八、Kafka总结
    七、Kafka 用户日志上报实时统计之编码实践
    六、Kafka 用户日志上报实时统计之分析与设计
    五、Kafka 用户日志上报实时统计之 应用概述
    四、Kafka 核心源码剖析
    三、消息处理过程与集群维护
  • 原文地址:https://www.cnblogs.com/YiXiaoZhou/p/6326612.html
Copyright © 2020-2023  润新知