• 快速入手一个简单的分类网络



    本系列文章由 @yhl_leo 出品,转载请注明出处。
    文章链接: http://blog.csdn.net/yhl_leo/article/details/53727411


    在以前的一篇博客中,我整理了如何根据CIFAR10的数据组织方式,制作自己的数据集,然后略微调整tensorflow 提供的demo进行训练,获得了一些关注,现在重新公布一个简单的方法,不需要制作像CIFAR10那样的数据集,也不用lmdb数据格式,直接使用原始数据,利用caffe训练简单的分类网络。

    发布于GitHub: yhlleo/CreateSimpleNetworks.

    在caffe的layer中,已有image_data_layer,对于image+label类型的训练数据,数据读取过程很简单:

      LOG(INFO) << "Opening file " << source;
      std::ifstream infile(source.c_str());
      string line;
      size_t pos;
      int label;
      while (std::getline(infile, line)) {
        pos = line.find_last_of(' ');
        label = atoi(line.substr(pos + 1).c_str());
        lines_.push_back(std::make_pair(line.substr(0, pos), label));
      }
    
      CHECK(!lines_.empty()) << "File is empty";
    
      if (this->layer_param_.image_data_param().shuffle()) {
        // randomly shuffle data
        LOG(INFO) << "Shuffling data";
        const unsigned int prefetch_rng_seed = caffe_rng_rand();
        prefetch_rng_.reset(new Caffe::RNG(prefetch_rng_seed));
        ShuffleImages();
      }
      LOG(INFO) << "A total of " << lines_.size() << " images.";

    即,需要制作训练文件列表格式为:

    ...
    /path/img1.jpg 0
    /path/img2.jpg 1
    ...

    完成训练文件列表后,简单搭建起一个小型网络:

    model

    指定好train.prototxtsolver.prototxtdeploy.prototxt文件,就可以训练。

    启动训练:

    ## train.py ##
    from __future__ import division
    import numpy as np
    import sys
    caffe_root = '/path/caffe/' 
    sys.path.insert(0, caffe_root)
    import caffe
    
    # init
    caffe.set_mode_gpu()
    caffe.set_device(0)
    
    solver = caffe.SGDSolver('/path/Models/solver.prototxt')
    solver.step(60000)

    批量测试:

    import numpy as np
    import os, cv2
    import time
    import caffe
    
    # Make sure that caffe is on the python path:
    caffe_root = '/path/caffe/'  
    import sys
    sys.path.insert(0, caffe_root + 'python')
    
    caffe.set_mode_gpu()
    caffe.set_device(0)
    
    def findImages(dir,topdown=True):
        im_list = []
        if not os.path.exists(dir):
            print "Path for {} not exist!".format(dir)
            raise
        else:
            for root, dirs, files in os.walk(dir, topdown):
                for fl in files:
                    im_list.append(os.path.join(root, fl))
        return im_list
    
    data_root = '/path/test/test1'
    test_lst = findImages(data_root)
    savefolder = '/path/test/'
    name = 'test1.txt'
    OutDir = open(savefolder+name, 'w');
    
    net = caffe.Net('/path/Models/xh_deploy.prototxt', 
        '/path/train/net_iter_60000.caffemodel', caffe.TEST)
    time_consum = []
    
    for idx in range(len(test_lst)):
        im = cv2.imread(test_lst[idx], cv2.IMREAD_UNCHANGED)
        sp = im.shape
    
        in_ = np.array(im, dtype=np.float32)
        in_ = in_[:,:,::-1]
        in_ = in_.transpose((2,0,1))
        net.blobs['data'].reshape(1, *in_.shape)
        net.blobs['data'].data[...] = in_
    
        start =time.clock()
        net.forward()   
        end = time.clock()
        time_consum.append(end-start)
    
        fuse = net.blobs['prob'].data[0]
        fname = test_lst[idx].split('/')[-1]
        OutDir.write("%s %.3f %.3f %.3f
    "%(fname, fuse[0], fuse[1], fuse[2]))
    print sum(time_consum)/len(time_consum)
    OutDir.close()

    测试结果(数据集分为两类),因此四列分别对应着:文件名,label为0的概率,label为1的概率和其它类别的概率:

    1-1.jpg 1.000 0.000 0.000
    1-2.jpg 1.000 0.000 0.000
    1-3.jpg 1.000 0.000 0.000
    1-4.jpg 1.000 0.000 0.000
    1-5.jpg 1.000 0.000 0.000
    1-6.jpg 1.000 0.000 0.000
    1-7.jpg 1.000 0.000 0.000
    1-8.jpg 1.000 0.000 0.000
    1001-1.jpg 0.594 0.405 0.001
    1002-1.jpg 0.009 0.990 0.000
    1002-10.jpg 1.000 0.000 0.000
    ...
  • 相关阅读:
    软件可靠性与安全性设计与实现知识梳理(软件可靠性与安全性高级技术研讨会心得)
    SSM框架整合
    不注册Tomcat服务,运行Tomcat不弹出JAVA控制台窗口
    ExtJS表单之复选框CheckboxGroup展示与取值
    ExtJS获取父子、兄弟容器元素方法
    LabVIEW之生产者/消费者模式--队列操作 彭会锋
    ExtJS Grid导出excel文件
    jeesite部署到Tomcat后,无法访问,cannot be resolved in either web.xml or the jar files deployed with this application
    滚动轮播插件——jCarouselLite
    统计学基础之假设检验
  • 原文地址:https://www.cnblogs.com/hehehaha/p/6332107.html
Copyright © 2020-2023  润新知