• caffe源码阅读(3)-Datalayer


    DataLayer是把数据从文件导入到网络的层,从网络定义prototxt文件可以看一下数据层定义

    layer {
      name: "data"
      type: "Data"
      top: "data"
      top: "label"
      include {
        phase: TRAIN
      }
      transform_param {
        mirror: true
        crop_size: 224
        mean_value: 104
        mean_value: 117
        mean_value: 123
      }
      data_param {
        source: "examples/imagenet/ilsvrc12_train_lmdb"
        batch_size: 32
        backend: LMDB
      }
    }
    l
    

    数据层包括了文件位置、文件类型、bath_size大小、图片变换等一些参数。可以看书,datalayer之后有top,没有bottom,即它是最底层的,它的forward运算只是负责把数据填充到top即可,并不使用bottom。

    在caffe中数据层不仅仅限于DataLayer,因为常常使用DataLayer导入数据,这里只是阅读DataLayer部分。

    数据层相关代码定义在data_layers.hpp中,DataLayer是从其他类派生出来的,一层一层来阅读。除了用到了datalayer相关的类,还用到了InternalThread,用来封装了线程,使用线程函数来取数据;类BlockingQueue是一个阻塞队列,用来辅助取数据;类DataReader,用来从文件读数据。

    BaseDataLayer

    BaseDataLayer直接从Layer派生出来。其成员变量有

      TransformationParameter transform_param_;//具体在protobuf中
      shared_ptr<DataTransformer<Dtype> > data_transformer_;//和输入数据转换相关。流入scale,crop,mirror等
      bool output_labels_;//是否有标签,无标签可以是无监督学习
    

    TransformationParameter是图像变换一些相关的参数,例如图像缩放、镜像变换、crop、减去均值等操作。
    DataTransformer类实现了图像变换的函数。

    Batch

    Batch是和批相关的类,只是把2个数据结构封装为一个,把数据和标签对应起来。

    template <typename Dtype>
    class Batch {
     public:
      Blob<Dtype> data_, label_;
    };
    

    BasePrefetchingDataLayer

    BasePrefetchingDataLayer派生自BaseDataLayerInternalThread。其中InternalThread是封装了线程,通过虚函数InternalThreadEntry来执行线程函数,用一个单独的线程函数来取数据。
    成员变量为:

      Batch<Dtype> prefetch_[PREFETCH_COUNT];
      BlockingQueue<Batch<Dtype>*> prefetch_free_;
      BlockingQueue<Batch<Dtype>*> prefetch_full_;
      Blob<Dtype> transformed_data_;//用来辅助实现图片变换操作
    

    PREFETCH_COUNT的大小,程序设为3,为了提前填充free队列。两个阻塞队列,逻辑功能比较简单:从free队列取数据结构,填充数据结构放到full队列;从full队列取数据,使用数据,清空数据结构,放到free队列。还有一个Blob结构,用来当做中间变量辅助图像变换。
    虚函数InternalThreadEntry是线程执行的函数,用来取数据

    //这里是取数据的线程
    template <typename Dtype>
    void BasePrefetchingDataLayer<Dtype>::InternalThreadEntry() {
    #ifndef CPU_ONLY
      cudaStream_t stream;
      if (Caffe::mode() == Caffe::GPU) {
        CUDA_CHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
      }
    #endif
    
      try {
        while (!must_stop()) {
          Batch<Dtype>* batch = prefetch_free_.pop();//从free_队列去数据结构
          load_batch(batch);//取数据,填充数据结构。在其派生类实现的
    #ifndef CPU_ONLY
          if (Caffe::mode() == Caffe::GPU) {
            batch->data_.data().get()->async_gpu_push(stream);//异步,把数据同步到GPU,使用Syncedmem->async_gpu_push
            CUDA_CHECK(cudaStreamSynchronize(stream));
          }
    #endif
          prefetch_full_.push(batch);//把数据放到full_队列
        }
      } catch (boost::thread_interrupted&) {
        // Interrupted exception is expected on shutdown
      }
    #ifndef CPU_ONLY
      if (Caffe::mode() == Caffe::GPU) {
        CUDA_CHECK(cudaStreamDestroy(stream));
      }
    #endif
    }
    

    数据层的forward函数不进行计算,不使用bottom,只是准备数据,填充到top

    template <typename Dtype>
    void BasePrefetchingDataLayer<Dtype>::Forward_cpu(
        const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
      Batch<Dtype>* batch = prefetch_full_.pop("Data layer prefetch queue empty");//从full队列取数据
      // Reshape to loaded data.
      top[0]->ReshapeLike(batch->data_);//调整top大小,一次读取一个batch大小的数据
      // Copy the data。把数据拷贝到top中
      caffe_copy(batch->data_.count(), batch->data_.cpu_data(),
                 top[0]->mutable_cpu_data());
      DLOG(INFO) << "Prefetch copied";
      if (this->output_labels_) {//如果有标签,也要把标签拷贝到top中
        // Reshape to loaded labels.
        top[1]->ReshapeLike(batch->label_);
        // Copy the labels.
        caffe_copy(batch->label_.count(), batch->label_.cpu_data(),
            top[1]->mutable_cpu_data());
      }
    
      prefetch_free_.push(batch);//用过的数据结构,放回free队列
    }
    

    DataLayer

    DataLayer是真正在网络中使用的类,派生自BasePrefetchingDataLayer。成员变量为:

    DataReader reader_;
    

    DataReader负责从硬盘读数据到一个队列,之后提供给data_layer使用.即使并行运行多个solver,也只有一个线程来读数据,这样可以确保'顺序'取数据,不同的solver取到的数据不同.
    DataReader的没有bottom,top中,如果没有标签,blob数量为1;有标签blob数量为2。
    虚函数load_batch,一次导入一个batch_size大小的数据;之后进行DataTransformer变换。

  • 相关阅读:
    LeetCode121.买卖股票的最佳时机
    OpenFunction 应用系列之一: 以 Serverless 的方式实现 Kubernetes 日志告警
    KubeSphere 核心架构浅析
    云原生爱好者周刊:服务网格的困境与破局
    DG:11.2.0.4 RAC在线duplicate恢复DG
    ORA-17629: Cannot connect to the remote database server
    DG:RFS[8]: No standby redo logfiles created for thread 2
    U盘内容不显示?U盘有文件却看不见?
    【CSS】特殊符号content编码及作为字体图标使用方法
    Python中的if __name__ == '__main__'(转载)
  • 原文地址:https://www.cnblogs.com/korbin/p/5615502.html
Copyright © 2020-2023  润新知